# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest
from types import ModuleType
from typing import Any, Callable, Optional, Tuple

import torch
from executorch.exir import ExecutorchBackendConfig, ExecutorchProgramManager, to_edge
from executorch.exir.passes import MemoryPlanningPass
from torch.export import export


class ModuleAdd(torch.nn.Module):
    """Two-input test module whose ``forward`` returns the sum of its inputs."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return the elementwise sum ``x + y``."""
        total = x + y
        return total

    def get_methods_to_export(self):
        """Return the names of the methods to export (only ``forward``)."""
        method_names = ("forward",)
        return method_names

    def get_inputs(self):
        """Return example inputs for ``forward``: two separate 2x2 tensors of ones."""
        lhs = torch.ones(2, 2)
        rhs = torch.ones(2, 2)
        return (lhs, rhs)


class ModuleMulti(torch.nn.Module):
    """Test module with two exportable entry points, ``forward`` and ``forward2``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        """Return the elementwise sum ``x + y``."""
        total = x + y
        return total

    def forward2(self, x, y):
        """Return ``x + y + 1``."""
        # Same evaluation order as the single expression: (x + y) first, then + 1.
        total = x + y
        return total + 1

    def get_methods_to_export(self):
        """Return the names of the methods to export, in order."""
        method_names = ("forward", "forward2")
        return method_names

    def get_inputs(self):
        """Return example inputs shared by both methods: two 2x2 tensors of ones."""
        lhs = torch.ones(2, 2)
        rhs = torch.ones(2, 2)
        return (lhs, rhs)


class ModuleAddSingleInput(torch.nn.Module):
    """Single-input test module whose ``forward`` adds the input to itself."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x + x``."""
        doubled = x + x
        return doubled

    def get_methods_to_export(self):
        """Return the names of the methods to export (only ``forward``)."""
        method_names = ("forward",)
        return method_names

    def get_inputs(self):
        """Return example inputs for ``forward``: a 1-tuple holding a 2x2 tensor of ones."""
        example = torch.ones(2, 2)
        return (example,)


class ModuleAddConstReturn(torch.nn.Module):
    """Test module that returns both ``x + state`` and the ``state`` tensor itself."""

    def __init__(self):
        super().__init__()
        # Held as a plain tensor attribute (not registered as a parameter or buffer).
        self.state = torch.ones(2, 2)

    def forward(self, x):
        """Return the pair ``(x + self.state, self.state)``."""
        constant = self.state
        return x + constant, constant

    def get_methods_to_export(self):
        """Return the names of the methods to export (only ``forward``)."""
        method_names = ("forward",)
        return method_names

    def get_inputs(self):
        """Return example inputs for ``forward``: a 1-tuple holding a 2x2 tensor of ones."""
        example = torch.ones(2, 2)
        return (example,)


def create_program(
    eager_module: torch.nn.Module,
    et_config: Optional[ExecutorchBackendConfig] = None,
) -> Tuple[ExecutorchProgramManager, Tuple[Any, ...]]:
    """Export ``eager_module`` and lower it to an ExecuTorch program.

    Every method named by ``eager_module.get_methods_to_export()`` is exported
    with the example inputs returned by ``eager_module.get_inputs()``.

    Args:
        eager_module: Module to export. It must provide ``get_inputs()`` and
            ``get_methods_to_export()``.
        et_config: Optional backend config passed to ``to_executorch``.

    Returns:
        A ``(program_manager, inputs)`` tuple, where ``inputs`` are the example
        inputs that were used for the export.
    """

    class WrapperModule(torch.nn.Module):
        """Exposes an arbitrary callable as the ``forward`` of a module."""

        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, *args, **kwargs):
            return self.fn(*args, **kwargs)

    # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
    #  is not a function.
    inputs = eager_module.get_inputs()
    # pyre-fixme[29]: `Union[torch._tensor.Tensor, torch.nn.modules.module.Module]`
    #  is not a function.
    method_names = eager_module.get_methods_to_export()

    # Every exported method is traced with the same example inputs.
    input_map = dict.fromkeys(method_names, inputs)

    # Wrap each bound method in a module so it can be handed to `export`.
    exported_methods = {
        name: export(WrapperModule(getattr(eager_module, name)), example_args)
        for name, example_args in input_map.items()
    }

    # Lower the exported methods to the edge dialect, then to an ExecuTorch program.
    edge_manager = to_edge(exported_methods)
    exec_prog = edge_manager.to_executorch(config=et_config)

    # Dump the resulting program in verbose mode (presumably a debugging aid).
    exec_prog.dump_executorch_program(verbose=True)
    return (exec_prog, inputs)


def make_test(  # noqa: C901
    tester: unittest.TestCase,
    runtime: ModuleType,
) -> Callable[[unittest.TestCase], None]:
    """
    Returns a function that operates as a test case within a unittest.TestCase class.

    Used to allow the test code for pybindings to be shared across different pybinding libs
    which will all have different load functions. In this case each individual test case is a
    subfunction of wrapper.
    """
    load_fn: Callable = runtime._load_for_executorch_from_buffer

    def wrapper(tester: unittest.TestCase) -> None:

        ######### TEST CASES #########

        def test_e2e(tester):
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            executorch_output = executorch_module.forward(inputs)[0]

            # The test module adds the two inputs, so its output should be the same
            # as adding them directly.
            expected = inputs[0] + inputs[1]

            tester.assertEqual(str(expected), str(executorch_output))

        def test_multiple_entry(tester):

            program, inputs = create_program(ModuleMulti())
            executorch_module = load_fn(program.buffer)

            executorch_output = executorch_module.forward(inputs)[0]
            tester.assertTrue(torch.allclose(executorch_output, torch.ones(2, 2) * 2))

            executorch_output2 = executorch_module.run_method("forward2", inputs)[0]
            tester.assertTrue(torch.allclose(executorch_output2, torch.ones(2, 2) * 3))

        def test_output_lifespan(tester):
            def lower_function_call():
                program, inputs = create_program(ModuleMulti())
                executorch_module = load_fn(program.buffer)

                return executorch_module.forward(inputs)
                # executorch_module is destructed here and all of its memory is freed

            outputs = lower_function_call()
            tester.assertTrue(torch.allclose(outputs[0], torch.ones(2, 2) * 2))

        def test_module_callable(tester):
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            # Invoke the callable on executorch_module instead of calling module.forward.
            executorch_output = executorch_module(inputs)[0]

            # The test module adds the two inputs, so its output should be the same
            # as adding them directly.
            expected = inputs[0] + inputs[1]
            tester.assertEqual(str(expected), str(executorch_output))

        def test_module_single_input(tester):
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAddSingleInput())

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            # Inovke the callable on executorch_module instead of calling module.forward.
            # Use only one input to test this case.
            executorch_output = executorch_module(inputs[0])[0]

            # The test module adds the two inputs, so its output should be the same
            # as adding them directly.
            expected = inputs[0] + inputs[0]
            tester.assertEqual(str(expected), str(executorch_output))

        def test_stderr_redirect(tester):
            import sys
            from io import StringIO

            class RedirectedStderr:
                def __init__(self):
                    self._stderr = None
                    self._string_io = None

                def __enter__(self):
                    self._stderr = sys.stderr
                    sys.stderr = self._string_io = StringIO()
                    return self

                def __exit__(self, type, value, traceback):
                    sys.stderr = self._stderr

                def __str__(self):
                    return self._string_io.getvalue()

            with RedirectedStderr() as out:
                try:
                    # Create an ExecuTorch program from ModuleAdd.
                    exported_program, inputs = create_program(ModuleAdd())

                    # Use pybindings to load and execute the program.
                    executorch_module = load_fn(exported_program.buffer)

                    # add an extra input to trigger error
                    inputs = (*inputs, 1)

                    # Invoke the callable on executorch_module instead of calling module.forward.
                    executorch_output = executorch_module(inputs)[0]  # noqa
                    tester.assertFalse(True)  # should be unreachable
                except Exception:
                    tester.assertTrue(str(out).find("The length of given input array"))

        def test_quantized_ops(tester):
            eager_module = ModuleAdd()

            from executorch.exir import EdgeCompileConfig
            from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
            from executorch.kernels import quantized  # noqa: F401
            from torch.ao.quantization import get_default_qconfig_mapping
            from torch.ao.quantization.backend_config.executorch import (
                get_executorch_backend_config,
            )
            from torch.ao.quantization.quantize_fx import (
                _convert_to_reference_decomposed_fx,
                prepare_fx,
            )

            qconfig_mapping = get_default_qconfig_mapping("qnnpack")
            example_inputs = (
                torch.ones(1, 5, dtype=torch.float32),
                torch.ones(1, 5, dtype=torch.float32),
            )
            m = prepare_fx(
                eager_module,
                qconfig_mapping,
                example_inputs,
                backend_config=get_executorch_backend_config(),
            )
            m = _convert_to_reference_decomposed_fx(m)
            config = EdgeCompileConfig(_check_ir_validity=False)
            m = to_edge(export(m, example_inputs), compile_config=config)
            m = m.transform([QuantFusionPass(_fix_node_meta_val=True)])

            exec_prog = m.to_executorch()

            executorch_module = load_fn(exec_prog.buffer)
            executorch_output = executorch_module.forward(example_inputs)[0]

            expected = example_inputs[0] + example_inputs[1]
            tester.assertEqual(str(expected), str(executorch_output))

        def test_constant_output_not_memory_planned(tester):
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(
                ModuleAddConstReturn(),
                et_config=ExecutorchBackendConfig(
                    memory_planning_pass=MemoryPlanningPass(alloc_graph_output=False)
                ),
            )

            exported_program.dump_executorch_program(verbose=True)

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            # Invoke the callable on executorch_module instead of calling module.forward.
            # Use only one input to test this case.
            executorch_output = executorch_module((torch.ones(2, 2),))
            print(executorch_output)

            # The test module adds the input to torch.ones(2,2), so its output should be the same
            # as adding them directly.
            expected = torch.ones(2, 2) + torch.ones(2, 2)
            tester.assertEqual(str(expected), str(executorch_output[0]))

            # The test module returns the state. Check that its value is correct.
            tester.assertEqual(str(torch.ones(2, 2)), str(executorch_output[1]))

        def test_method_meta(tester) -> None:
            exported_program, inputs = create_program(ModuleAdd())

            # Use pybindings to load the program and query its metadata.
            executorch_module = load_fn(exported_program.buffer)
            meta = executorch_module.method_meta("forward")

            # Ensure that all these APIs work even if the module object is destroyed.
            del executorch_module
            tester.assertEqual(meta.name(), "forward")
            tester.assertEqual(meta.num_inputs(), 2)
            tester.assertEqual(meta.num_outputs(), 1)
            # Common string for all these tensors.
            tensor_info = "TensorInfo(sizes=[2, 2], dtype=Float, is_memory_planned=True, nbytes=16)"
            float_dtype = 6
            tester.assertEqual(
                str(meta),
                "MethodMeta(name='forward', num_inputs=2, "
                f"input_tensor_meta=['{tensor_info}', '{tensor_info}'], "
                f"num_outputs=1, output_tensor_meta=['{tensor_info}'])",
            )

            input_tensors = [meta.input_tensor_meta(i) for i in range(2)]
            output_tensor = meta.output_tensor_meta(0)
            # Check that accessing out of bounds raises IndexError.
            with tester.assertRaises(IndexError):
                meta.input_tensor_meta(2)
            # Test that tensor metadata can outlive method metadata.
            del meta
            tester.assertEqual([t.sizes() for t in input_tensors], [(2, 2), (2, 2)])
            tester.assertEqual(
                [t.dtype() for t in input_tensors], [float_dtype, float_dtype]
            )
            tester.assertEqual(
                [t.is_memory_planned() for t in input_tensors], [True, True]
            )
            tester.assertEqual([t.nbytes() for t in input_tensors], [16, 16])
            tester.assertEqual(str(input_tensors), f"[{tensor_info}, {tensor_info}]")

            tester.assertEqual(output_tensor.sizes(), (2, 2))
            tester.assertEqual(output_tensor.dtype(), float_dtype)
            tester.assertEqual(output_tensor.is_memory_planned(), True)
            tester.assertEqual(output_tensor.nbytes(), 16)
            tester.assertEqual(str(output_tensor), tensor_info)

        def test_bad_name(tester) -> None:
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())

            # Use pybindings to load and execute the program.
            executorch_module = load_fn(exported_program.buffer)
            # Invoke the callable on executorch_module instead of calling module.forward.
            with tester.assertRaises(RuntimeError):
                executorch_module.run_method("not_a_real_method", inputs)

        def test_verification_config(tester) -> None:
            # Create an ExecuTorch program from ModuleAdd.
            exported_program, inputs = create_program(ModuleAdd())
            Verification = runtime.Verification

            # Use pybindings to load and execute the program.
            for config in [Verification.Minimal, Verification.InternalConsistency]:
                executorch_module = load_fn(
                    exported_program.buffer,
                    enable_etdump=False,
                    debug_buffer_size=0,
                    program_verification=config,
                )

                executorch_output = executorch_module.forward(inputs)[0]

                # The test module adds the two inputs, so its output should be the same
                # as adding them directly.
                expected = inputs[0] + inputs[1]

                tester.assertEqual(str(expected), str(executorch_output))

        ######### RUN TEST CASES #########
        test_e2e(tester)
        test_multiple_entry(tester)
        test_output_lifespan(tester)
        test_module_callable(tester)
        test_module_single_input(tester)
        test_stderr_redirect(tester)
        test_quantized_ops(tester)
        test_constant_output_not_memory_planned(tester)
        test_method_meta(tester)
        test_bad_name(tester)
        test_verification_config(tester)

    return wrapper
