# Owner(s): ["oncall: export"]

import unittest
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._dynamo.test_case import TestCase
from torch._export.converter import TS2EPConverter
from torch.export import ExportedProgram
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests
from torch.testing._internal.torchbind_impls import (
    _empty_tensor_queue,
    init_torchbind_implementations,
)


# Test decorator: skips the decorated test (reason "requires cuda") unless
# torch.cuda.is_available() is true when this module is imported.
requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")


class TestConverter(TestCase):
    def setUp(self):
        # Initialise the torchbind test implementations first; the fake class
        # registration and the op lookups below refer to _TorchScriptTesting.
        init_torchbind_implementations()

        # Pure-Python stand-in registered for _TorchScriptTesting::_TensorQueue.
        # It keeps its elements in whatever sequence is passed as ``queue``.
        @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
        class FakeTensorQueue:
            def __init__(self, queue):
                self.queue = queue

            @classmethod
            def __obj_unflatten__(cls, flattened_ctx):
                # The flattened context is a sequence of (name, value) pairs
                # that become keyword arguments of __init__.
                init_kwargs = dict(flattened_ctx)
                return cls(**init_kwargs)

            def size(self):
                return len(self.queue)

            def float_size(self):
                return float(len(self.queue))

            def is_empty(self):
                return len(self.queue) == 0

            def push(self, x):
                self.queue.append(x)

            def pop(self):
                # Popping an empty queue yields an empty 0-dim tensor instead
                # of raising.
                if not self.is_empty():
                    return self.queue.pop(0)
                return torch.empty([])

        testing_ops = torch.ops._TorchScriptTesting
        self.torch_bind_ops = [
            testing_ops.queue_pop,
            testing_ops.queue_push,
            testing_ops.queue_size,
        ]

    def tearDown(self):
        torch._library.fake_class_registry.deregister_fake_class(
            "_TorchScriptTesting::_TensorQueue"
        )

    def _check_equal_ts_ep_converter(
        self,
        M,
        inp,
        option: Optional[List[str]] = None,
        check_persistent=False,
        lifted_tensor_constants=None,
    ) -> List[ExportedProgram]:
        # By default, it tests both jit.trace and jit.script.
        if option is None:
            option = ["trace", "script"]

        if check_persistent:
            num_iterations = 10
        else:
            num_iterations = 1

        ep_list = []
        for opt in option:
            if opt == "script":
                # Separate two models for testing non-functional effects
                if check_persistent:
                    original_ts_model = torch.jit.script(M())
                    ts_model = torch.jit.script(M())
                    eager_model = M()
                else:
                    original_ts_model = torch.jit.script(M)
                    ts_model = torch.jit.script(M)
                    eager_model = M
            elif opt == "trace":
                if check_persistent:
                    original_ts_model = torch.jit.trace(M(), inp)
                    ts_model = torch.jit.trace(M(), inp)
                    eager_model = M()
                else:
                    original_ts_model = torch.jit.trace(M, inp)
                    ts_model = torch.jit.trace(M, inp)
                    eager_model = M
            else:
                raise RuntimeError(f"Unrecognized mode for torch.jit: {opt}")

            converter = TS2EPConverter(ts_model, inp)
            ep = converter.convert()
            ep_list.append(ep)

            for _ in range(num_iterations):
                orig_out, _ = pytree.tree_flatten(original_ts_model(*inp))
                ep_out, _ = pytree.tree_flatten(ep.module()(*inp))

                # Check module.
                if isinstance(eager_model, torch.nn.Module):
                    expected_state_dict = OrderedDict()
                    expected_state_dict.update(ts_model.state_dict())
                    if lifted_tensor_constants:
                        expected_state_dict.update(lifted_tensor_constants)
                    self.assertEqual(
                        ep.state_dict.keys(),
                        expected_state_dict.keys(),
                    )

                # Check results
                self._check_tensor_list_equal(ep_out, orig_out)
        return ep_list

    def _check_tensor_list_equal(self, xs: List[torch.Tensor], ys: List[torch.Tensor]):
        self.assertEqual(len(xs), len(ys))
        for x, y in zip(xs, ys):
            if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
                self.assertEqual(x.shape, y.shape)
                self.assertTrue(torch.allclose(x, y))
            else:
                self.assertEqual(type(x), type(y))
                self.assertEqual(x, y)

    def test_ts2ep_converter_basic(self):
        # Smoke test: one single-output and one multi-output module, each
        # checked under the helper's default modes (jit.trace and jit.script).
        class MSingle(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        class MMulti(torch.nn.Module):
            def forward(self, x, y):
                x = x.cos() + 1
                y = y.sin() - 1
                return x, y

        inp = (torch.ones(1, 3), torch.ones(1, 3))
        self._check_equal_ts_ep_converter(MSingle(), inp)
        self._check_equal_ts_ep_converter(MMulti(), inp)

    def test_ts2ep_converter_container_output(self):
        # Checks conversion of modules whose outputs are containers: a list, a
        # tuple and a nested dict of tensors.

        # Output is a List.
        class MOutputList(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return [a, b]

        # Output is a Tuple.
        class MOutputTuple(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return (a, b)

        # Output is a Dict.
        class MOutputDict(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return {"data": {"mul": a, "add": b}}

        inp = (torch.tensor(4), torch.tensor(4))

        # Traced function must use immutable structure as output, so the list
        # and dict variants are only checked with jit.script; the tuple variant
        # uses the default modes (trace and script).
        self._check_equal_ts_ep_converter(MOutputList(), inp, ["script"])
        self._check_equal_ts_ep_converter(MOutputTuple(), inp)
        self._check_equal_ts_ep_converter(MOutputDict(), inp, ["script"])

    def test_aten_dim(self):
        # The module reads x.dim() and uses it as the length of a new ones
        # tensor; checked under both default modes (trace and script).
        class Module(torch.nn.Module):
            def forward(self, x):
                num_dim = x.dim()
                return torch.ones(num_dim)

        inp = (torch.ones(1, 3),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten_len(self):
        # Exercises len() on several input types. ``Module`` is redefined for
        # every case so that the annotation of ``x`` selects the overload named
        # in the comment next to each input. Only the Tensor case runs under
        # both default modes; all container cases are script-only.
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                length = len(x)
                return torch.ones(length)

        # aten::len.Tensor
        inp = (torch.ones(2, 3),)
        self._check_equal_ts_ep_converter(Module(), inp)

        class Module(torch.nn.Module):
            def forward(self, x: List[int]):
                length = len(x)
                return torch.ones(length)

        # aten::len.t
        inp = ([1, 2, 3],)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[int, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_int
        inp = ({1: "a", 2: "b", 3: "c"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[bool, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_bool
        inp = ({True: "a", False: "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[float, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_float
        inp = ({1.2: "a", 3.4: "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[torch.Tensor, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_Tensor
        inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # aten::len.str and aten::len.Dict_str are not supported
        # since torch._C._jit_flatten does not support str
        # inp = ("abcdefg",)
        # self._check_equal_ts_ep_converter(Module(), inp)
        # inp = ({"a": 1, "b": 2},)
        # self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten_add_t(self):
        # Python list concatenation with ``+``: the input list is appended to
        # ``out`` twice and the list is concatenated with torch.cat after each
        # step. Checked with jit.script only.
        class Module(torch.nn.Module):
            def forward(self, x: List[torch.Tensor]):
                out = []
                out = out + x
                a = torch.cat(out)
                out = out + x
                b = torch.cat(out)
                return a, b

        inp = ([torch.ones(2, 3), torch.ones(2, 3)],)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_aten_to_dtype_with_mutating_storage(self):
        # Casts x to y's dtype and then mutates the cast result in place with
        # index_put_ (row 0 is overwritten with y). The inputs below have
        # different dtypes (ones() vs. an integer tensor), assuming the default
        # floating dtype is in effect.
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                x = x.to(y.dtype)
                torch.ops.aten.index_put_(x, [torch.tensor([0])], y)
                return x

        inp = (torch.ones(2, 3), torch.tensor([0, 0, 0]))
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_min(self):
        # Exercises min() over int/float scalars and lists; the comment above
        # each expression names the prim::min overload it is meant to hit.
        # The results are summed into the length of the returned ones tensor
        # so that every value contributes to the compared output.
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                x_len = len(x)
                y_len = len(y)

                # prim::min.int
                len_int = min(x_len, y_len)

                # prim::min.float
                len_float = int(min(x_len * 2.0, y_len * 2.0))

                # prim::min.self_int
                len_self_int = min([x_len, y_len])

                # prim::min.self_float
                len_self_float = int(min([x_len * 2.0, y_len * 2.0]))

                # prim::min.float_int
                len_float_int = int(min(x_len * 2.0, y_len))

                # prim::min.int_float
                len_int_float = int(min(x_len, y_len * 2.0))

                return torch.ones(
                    len_int
                    + len_float
                    + len_self_int
                    + len_self_float
                    + len_float_int
                    + len_int_float
                )

        # len(x) == 10 and len(y) == 5 for these inputs.
        inp = (torch.randn(10, 2), torch.randn(5))
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_max(self):
        # Exercises max() over int/float scalars and lists; the comment above
        # each expression names the prim::max overload it is meant to hit.
        # The results are summed into the length of the returned ones tensor
        # so that every value contributes to the compared output.
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                x_len = len(x)
                y_len = len(y)

                # prim::max.int
                len_int = max(x_len, y_len)

                # prim::max.float
                len_float = int(max(x_len * 2.0, y_len * 2.0))

                # prim::max.self_int
                len_self_int = max([x_len, y_len])

                # prim::max.self_float
                len_self_float = int(max([x_len * 2.0, y_len * 2.0]))

                # prim::max.float_int
                len_float_int = int(max(x_len * 2.0, y_len))

                # prim::max.int_float
                len_int_float = int(max(x_len, y_len * 2.0))

                return torch.ones(
                    len_int
                    + len_float
                    + len_self_int
                    + len_self_float
                    + len_float_int
                    + len_int_float
                )

        # len(x) == 10 and len(y) == 5 for these inputs.
        inp = (torch.randn(10, 2), torch.randn(5))
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten___getitem___list(self):
        # Indexes into the sequence returned by torch.split and returns the
        # first chunk.
        class Module(torch.nn.Module):
            def forward(self, x):
                y = torch.split(x, 2)
                return y[0]

        inp = (torch.rand((3, 2)),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten___getitem___dict(self):
        # Builds dicts keyed by int, str, bool and float from the chunks of
        # torch.split and looks one entry up in each of them.
        class Module(torch.nn.Module):
            def forward(self, x):
                y = torch.split(x, 2)
                d_int = {0: y[0], 1: y[1]}
                d_str = {"0": y[0], "1": y[1]}
                d_bool = {True: y[0], False: y[1]}
                d_float = {0.1: y[0], 2.3: y[1]}
                return d_int[0], d_str["0"], d_bool[True], d_float[0.1]

        inp = (torch.rand((3, 2)),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_device(self):
        # Reads x.device and creates a new tensor on that device.
        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand(3, 4),)
        self._check_equal_ts_ep_converter(Module(), inp)

    @requires_cuda
    def test_prim_device_cuda(self):
        # Same as test_prim_device, but with the input placed on "cuda:0";
        # skipped when CUDA is unavailable.
        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand((3, 4), device="cuda:0"),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_dtype(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                dtype = x.dtype
                return torch.ones(2, 3, dtype=dtype)

        for dtype in [
            torch.float32,
            torch.double,
        ]:
            inp = (torch.rand((3, 4), dtype=dtype),)
            self._check_equal_ts_ep_converter(Module(), inp)

        for dtype in [
            torch.uint8,
            torch.int8,
            torch.int32,
        ]:
            inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),)
            self._check_equal_ts_ep_converter(Module(), inp)

    def test_convert_if_basic(self):
        # Data-dependent ``if`` on a tensor condition with a single tensor
        # returned from each branch.
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                if x:
                    return y * y
                else:
                    return y + y

        # Conversion is done with a True condition.
        inp = (torch.tensor(True), torch.tensor(4))
        ep_list = self._check_equal_ts_ep_converter(M(), inp)

        # Re-run with a False condition to take the other branch. ep_list[0]
        # comes from jit.trace (first of the helper's default modes) and is
        # skipped here.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(torch.tensor(False), torch.tensor(4)),
                M()(torch.tensor(False), torch.tensor(4)),
            )

    def test_convert_if_tuple_out(self):
        # Data-dependent ``if`` where each branch produces a tuple through a
        # helper method, and the tuple is consumed after the ``if``.
        class M(torch.nn.Module):
            def true_fn(self, y, z):
                return (z * z, z + z)

            def false_fn(self, y, z):
                return (y * y * y, y + y)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                z = y * y

                if x:
                    res = self.true_fn(y, z)
                else:
                    res = self.false_fn(y, z)

                return res[0] + res[1]

        # Conversion is done with a True condition.
        inp = (torch.tensor(True), torch.tensor(4))
        ep_list = self._check_equal_ts_ep_converter(M(), inp)

        # Re-run with a False condition to take the other branch. ep_list[0]
        # comes from jit.trace (first of the helper's default modes) and is
        # skipped here.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(torch.tensor(False), torch.tensor(4)),
                M()(torch.tensor(False), torch.tensor(4)),
            )

    def test_convert_if_multiple_out(self):
        # Data-dependent ``if`` where each branch assigns two separate
        # variables (res1, res2) that are both used after the ``if``.
        class M(torch.nn.Module):
            def true_fn(self, y, z):
                return z * z

            def false_fn(self, y, z):
                return y * y * y

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                z = y * y

                if x:
                    res1 = self.true_fn(y, z)
                    res2 = y
                else:
                    res1 = z
                    res2 = self.false_fn(y, z)

                return res1 + res2

        # Conversion is done with a True condition.
        inp = (torch.tensor(True), torch.tensor(4))
        ep_list = self._check_equal_ts_ep_converter(M(), inp)

        # Re-run with a False condition to take the other branch. ep_list[0]
        # comes from jit.trace (first of the helper's default modes) and is
        # skipped here.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(torch.tensor(False), torch.tensor(4)),
                M()(torch.tensor(False), torch.tensor(4)),
            )

    def test_profiler__record_function(self):
        # Wraps a small computation between the profiler's
        # _record_function_enter_new / _record_function_exit ops and checks
        # that the computed value still matches after conversion.
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                handle = torch.ops.profiler._record_function_enter_new("foo", None)
                y = x * 2 + 4
                torch.ops.profiler._record_function_exit(handle)
                return y

        x = torch.randn(10, 10)
        self._check_equal_ts_ep_converter(Module(), (x,))

    def test_aten_floordiv(self):
        # Floor division of a tensor by a Python int.
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x // 2

        x = torch.randn(10, 10)
        self._check_equal_ts_ep_converter(Module(), (x,))

    def test_aten___is__(self):
        # Returns the result of ``x is y`` together with a tensor output.
        # Checked with jit.script only.
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return x is y, z

        # Traced function must return output that has tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_aten___isnot__(self):
        # Returns the result of ``x is not y`` together with a tensor output.
        # Checked with jit.script only.
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return x is not y, z

        # Traced function must return output that has tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_aten___not__(self):
        # Applies ``not`` to the result of ``x is not y`` and returns it
        # together with a tensor output. Checked with jit.script only.
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return not (x is not y), z

        # Traced function must return output that has tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_ts2ep_converter_unpack(self):
        # Unpacking of a sequence produced inside forward (torch.split) and of
        # a tuple passed in as an argument.
        class MUnpackList(torch.nn.Module):
            def forward(self, x):
                x, y = torch.split(x, 2)
                return x + y

        class MUnpackTuple(torch.nn.Module):
            def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]):
                x, y = x_tuple
                x = x.cos()
                return x + y

        # A length-4 tensor splits into exactly two chunks of size 2.
        inp = (torch.ones(4),)
        self._check_equal_ts_ep_converter(MUnpackList(), inp)
        # The single positional argument is itself a 2-tuple of tensors.
        inp = ((torch.zeros(1, 4), torch.ones(1, 4)),)
        self._check_equal_ts_ep_converter(MUnpackTuple(), inp)

    @unittest.skipIf(
        IS_WINDOWS,
        "torch.cond doesn't go through torch.compile on windows"
        "causing output not normalized as list",
    )
    def test_convert_retrace_nested_scripted_modules(self):
        class Wrapper(torch.nn.Module):
            def __init__(self, mod) -> None:
                super().__init__()
                self.mod = mod

            def forward(self, x, y):
                return self.mod(x, y)

        class LinearM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x, y):
                return self.linear(y)

        class M(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                m = LinearM(dim)
                m = torch.jit.script(m)
                self.mod1 = m
                self.mod2 = Wrapper(m)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                if x:
                    return -self.mod1(x, y) - self.mod2(x, y)
                else:
                    return -self.mod1(x, y) + self.mod2(x, y)

        class NestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                m = M(dim)
                m = torch.jit.script(m)
                self.mod1 = m
                self.mod2 = Wrapper(m)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                if x:
                    return self.mod1(x, y) + self.mod2(x, y)
                else:
                    return self.mod1(x, y) - self.mod2(x, y)

        inp = (
            torch.tensor(True),
            torch.randn([3, 3]),
        )
        self._check_equal_ts_ep_converter(NestedM(3), inp)

    def test_convert_nn_module_with_nested_param(self):
        # Modules whose parameters (Linear layers) live at several nesting
        # depths: NestedM holds an M, SuperNestedM holds a NestedM. The helper
        # also compares state-dict keys because an nn.Module is passed.
        class M(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                return self.linear(x)

        class NestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)
                self.m = M(dim)

            def forward(self, x: torch.Tensor):
                return self.linear(self.m(x))

        class SuperNestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)
                self.m = NestedM(dim)

            def forward(self, x: torch.Tensor):
                return self.linear(self.m(x))

        inp = (torch.ones(3),)
        # Two levels of nesting.
        orig_m = NestedM(3)
        self._check_equal_ts_ep_converter(orig_m, inp)
        # Three levels of nesting.
        orig_m = SuperNestedM(3)
        self._check_equal_ts_ep_converter(orig_m, inp)

    def test_convert_nn_module_with_nested_buffer(self):
        # Modules whose buffers (torch.nn.Buffer) live at several nesting
        # depths: NestedM holds an M, SuperNestedM holds a NestedM, and every
        # level adds its own buffer ``w`` to the result.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.nn.Buffer(torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + x

        class NestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = M()
                self.w = torch.nn.Buffer(torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + self.m(x)

        class SuperNestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = NestedM()
                self.w = torch.nn.Buffer(torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + self.m(x)

        inp = (torch.ones(1),)
        # Two levels of nesting.
        orig_m = NestedM()
        self._check_equal_ts_ep_converter(orig_m, inp)
        # Three levels of nesting.
        orig_m = SuperNestedM()
        self._check_equal_ts_ep_converter(orig_m, inp)

    def test_convert_nn_module_with_nested_if_and_buffer(self):
        # Buffers used inside data-dependent ``if`` branches at several
        # nesting depths.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w = torch.nn.Buffer(torch.randn(1))
                # Plain Python int attribute, used in forward next to the
                # buffer.
                self.count = 1

            def forward(self, x: torch.Tensor):
                return self.w + x + self.count

        class NestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M()
                self.m2 = M()
                self.w = torch.nn.Buffer(torch.randn(1))

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.w + self.m1(x)
                else:
                    return self.w + self.m2(x)

        # Super nested, parameters need to be lifted
        # multiple times.
        class SuperNestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = NestedM()
                self.m2 = NestedM()
                self.w = torch.nn.Buffer(torch.randn(1))

            def forward(self, x: torch.Tensor):
                if torch.max(x) > 1:
                    return self.w + self.m1(x)
                else:
                    return self.w + self.m2(x)

        # Super nested module testing.
        inp = (torch.ones(1),)
        orig_m = SuperNestedM()
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        # Shift the input in place from 1 to 0 and compare again.
        # NOTE(review): neither input satisfies ``> 1`` (1 > 1 and 0 > 1 are
        # both false), so both checks take the else-branches. That is
        # presumably why the traced program in ep_list[0] is included here;
        # confirm whether covering the other branch was intended.
        t = inp[0]
        t -= 1
        for ep in ep_list:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

    @unittest.skipIf(
        IS_WINDOWS,
        "torch.cond doesn't go through torch.compile on windows"
        "causing output not normalized as list",
    )
    def test_convert_nn_module_with_nested_if_and_param(self):
        class M(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                return self.linear(x)

        class NestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = M(dim)
                self.m2 = M(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Super nested, parameters neeed to lifted
        # multiple times.
        class SuperNestedM1(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = NestedM(dim)
                self.m2 = NestedM(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.max(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Super nested, even the input needs to be
        # lifted recursively due to value propogation optimiztaion.
        class SuperNestedM2(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = NestedM(dim)
                self.m2 = NestedM(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Basic module testing.
        inp = (torch.ones(3),)
        orig_m = M(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # Nested module testing.
        inp = (torch.ones(3),)
        orig_m = NestedM(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        # Skip jit.traced because it specializes on one path.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # Super nested module testing.
        inp = (torch.ones(3),)
        orig_m = SuperNestedM1(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        # Skip jit.traced because it specializes on one path.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # Super nested module testing.
        inp = (torch.ones(3),)
        orig_m = SuperNestedM2(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        # Skip jit.traced because it specializes on one path.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

    def test_ts2ep_converter_contains(self):
        # Exercises the ``in`` operator: membership of a dtype in a list of
        # dtypes, and membership of a tensor in a dict keyed by tensors.
        # All cases are checked with jit.script only.
        class MIn(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x.dtype in [torch.float32, torch.float64]

        class MNotIn(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x.dtype in [torch.int8]

        class MTensorIn(torch.nn.Module):
            def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]):
                return x in x_dict

        # Traced function must return output that has tensors.
        # NOTE(review): torch.tensor(4) is an integer tensor (presumably
        # int64), so both MIn and MNotIn would evaluate to False for it;
        # confirm whether a floating-point input was intended for MIn.
        inp = (torch.tensor(4),)
        self._check_equal_ts_ep_converter(MIn(), inp, ["script"])
        self._check_equal_ts_ep_converter(MNotIn(), inp, ["script"])

        # TODO: update test to use reference for in.
        # The first input pairs x with a key of equal value, the second with a
        # key of a different value; in both cases x and the key are distinct
        # tensor objects.
        inp = (torch.tensor(4), {torch.tensor(4): "foo"})
        self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"])
        inp = (torch.tensor(1), {torch.tensor(4): "foo"})
        self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"])

    def test_ts2ep_converter_custom_op(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch._dynamo.config.capture_scalar_outputs = True
            torch._dynamo.config.capture_dynamic_output_shape_ops = True

            torch.library.define(
                "mylib::foo",
                "(Tensor x) -> Tensor",
                lib=lib,
            )

            # PyTorch custorm op implementation
            @torch.library.impl(
                "mylib::foo",
                "CompositeExplicitAutograd",
                lib=lib,
            )
            def foo_impl(x):
                return x + x

            # Meta function of the custom op.
            @torch.library.impl_abstract(
                "mylib::foo",
                lib=lib,
            )
            def foo_meta(x):
                return x + x

            class M(torch.nn.Module):
                def forward(self, x):
                    return torch.ops.mylib.foo(x)

            inp = (torch.randn(3, 3),)
            m = M()
            self._check_equal_ts_ep_converter(m, inp)

    def test_convert_func_without_param(self):
        # Converts plain Python functions (no nn.Module, hence no parameters
        # and no state-dict comparison in the helper).
        def func1(x, y):
            return x + y

        def func2(x, y):
            if x.sum() > 0:
                return x + y
            else:
                return x - y

        inp = (
            torch.tensor(1),
            torch.tensor(1),
        )
        self._check_equal_ts_ep_converter(func1, inp)

        ep_list = self._check_equal_ts_ep_converter(func2, inp)

        # Decrement x in place from 1 to 0 so that ``x.sum() > 0`` becomes
        # false and func2 takes its else-branch. ep_list[0] comes from
        # jit.trace (first of the helper's default modes) and is skipped.
        t = inp[0]
        t -= 1
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                func2(*inp),
            )

    def test_implicit_constant_to_tensor_handling(self):
        # Covers functions and modules that mix tensors with Python scalars,
        # tensor literals created inside the function, and size/numel values.

        # Tensor plus Python int.
        def func1(x):
            return x + 2

        # Python int scaling a tensor inside a larger expression.
        def func2(x, y):
            return x * y / (x - 2 * y) + y

        # Tensor literal created inside the function.
        def func3(x):
            return x + torch.tensor([3])

        # No inputs; a 0-dim inf tensor is used as the fill value.
        def func4():
            val = torch.tensor(float("inf"))
            return torch.full((10, 10), val)

        # No inputs; Python int multiplied with a freshly created tensor.
        def func5():
            x = -1
            return x * torch.ones(1, dtype=torch.float), torch.zeros(
                1, dtype=torch.float
            )

        # numel()/size() of inputs returned directly and fed into later ops.
        def func6(x1, x2, x3, x4):
            return (
                x1.numel(),
                x1.size(),
                x2.numel(),
                x2.size(),
                x3.numel(),
                x3.size(),
                x4.numel(),
                x4.size(),
                torch.ones(x1.numel()),  # Just make sure downstream ops still work.
                torch.ones(x1.size()),  # Just make sure downstream ops still work.
            )

        # Tensor stored as a plain attribute (not a parameter or buffer).
        class M1(torch.nn.Module):
            def __init__(self, value):
                super().__init__()
                self.x = torch.tensor(value)

            def forward(self):
                return self.x.clone()

        # Tensor literal created inside forward.
        class M2(torch.nn.Module):
            def forward(self, x):
                return torch.tensor(4) + x

        inp = (torch.randn([2, 2]),)
        self._check_equal_ts_ep_converter(func1, inp)
        inp = (torch.randn([2, 2]), torch.randn([2, 2]))
        self._check_equal_ts_ep_converter(func2, inp)

        inp = (torch.randn([2, 2]),)
        self._check_equal_ts_ep_converter(func3, inp)

        self._check_equal_ts_ep_converter(func4, ())
        self._check_equal_ts_ep_converter(M1(5), ())

        inp = (torch.randn(2),)
        self._check_equal_ts_ep_converter(M2(), inp)

        self._check_equal_ts_ep_converter(func5, ())
        # Four inputs of the same shape but different dtypes.
        inp = (
            torch.randn([2, 3, 4]).to(torch.int8),
            torch.randn([2, 3, 4]).to(torch.int32),
            torch.randn([2, 3, 4]).to(torch.float32),
            torch.randn([2, 3, 4]).to(torch.float64),
        )
        # ep_list is only needed by the disabled check below.
        ep_list = self._check_equal_ts_ep_converter(func6, inp)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     self.assertEqual(
        #         ep.module()(
        #             torch.randn([1, 1, 1]).to(torch.int8),
        #             torch.randn([1, 1, 1]).to(torch.int32),
        #             torch.randn([1, 1, 1]).to(torch.float32),
        #             torch.randn([1, 1, 1]).to(torch.float64),
        #         )[0], 1
        #     )

    def test_aten_tensor_dtype_int(self):
        # torch.tensor called with a Python scalar and an explicit, fixed dtype.
        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor(1, dtype=torch.int32)
                return y + x

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
        # Every converted program is expected to hold exactly one constant.
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 1)

    def test_aten_tensor_prim_dtype(self):
        # torch.tensor called with a Python scalar whose dtype is read from the
        # input tensor (x.dtype) rather than written as a literal.
        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor(1, dtype=x.dtype)
                return y + x

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
        # Every converted program is expected to hold exactly one constant.
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 1)

    def test_aten_tensor_dynamic(self):
        # torch.tensor built from values derived from the input's shape.

        # Case 1: a single scalar taken from x.shape[0].
        class M(torch.nn.Module):
            def forward(self, x):
                s = x.shape[0]
                y = torch.tensor(s)
                return y

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),))
        # No converted program is expected to hold any constants.
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 0)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     torch.testing.assert_close(
        #         ep.module()(torch.ones(4)),
        #         M()(torch.ones(4)),
        #     )

        # Case 2: a list mixing shape-derived values with a Python literal.
        class M(torch.nn.Module):
            def forward(self, x):
                s = x.shape[0]
                y = torch.tensor([s, s * 2, 1])
                return y

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),))
        # Tracing directly inlines the result as a tensor constant, so the first
        # program is excluded from the zero-constants check.
        # NOTE(review): this assumes ep_list[0] is the traced program -- confirm
        # against the helper, which is defined outside this section.
        for ep in ep_list[1:]:
            self.assertEqual(len(ep.constants), 0)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     torch.testing.assert_close(
        #         ep.module()(torch.ones(4)),
        #         M()(torch.ones(4)),
        #     )

    def test_prim_tolist(self):
        # Tensor.tolist() with an annotated list return type; both cases are
        # checked with the "script" option only.

        # 1-D tensor -> List[int].
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[int]:
                return x.tolist()

        inp = (torch.tensor([1, 2, 3]),)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # 2-D tensor -> List[List[int]].
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[List[int]]:
                return x.tolist()

        inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_get_tensor_constants(self):
        # Plain tensor attributes (not registered as parameters or buffers) read
        # from a module and from its submodule.
        # Since self.data is only read but not written, it is lifted as
        # constant tensors.
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.randn(3, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.data

        # Goo reads its own attribute, reaches into the submodule's attribute
        # (self.foo.data) and also calls the submodule.
        class Goo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.randn(3, 2)
                self.foo = Foo()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.data + self.foo.data + self.foo(x)

        inp = (torch.randn(3, 2),)
        goo = Goo()
        self._check_equal_ts_ep_converter(goo, inp)

    def test_prim_SetAttr(self):
        # Modules that assign to their own attributes inside forward. All cases
        # use the "script" option only.
        # NOTE(review): with check_persistent=True the module *class* is passed
        # instead of an instance; presumably the helper instantiates fresh copies
        # and verifies state across repeated calls -- the helper is defined
        # outside this section, confirm there.

        # Case 1: a buffer is reassigned, but the output does not read it.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.nn.Buffer(torch.ones(3, 2))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + x

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module, inp, ["script"], check_persistent=True
        )

        # Case 2: a buffer is reassigned and the output reads the new value.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.nn.Buffer(torch.ones(3, 2))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + self.data

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module, inp, ["script"], check_persistent=True
        )

        # Case 3: same as case 2, but self.data is a plain tensor attribute
        # rather than a registered buffer.
        # export lifts a tensor constant (self.data) as an input if it is not assigned.
        # If it is assigned, export will error and ask users to register it as a buffer.
        # In converter, we change tensor constants that are assigned as a buffer automatically,
        # since it might be hard to manually register them as buffers.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.data = torch.ones(3, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + self.data

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module,
            inp,
            ["script"],
            check_persistent=True,
            lifted_tensor_constants=OrderedDict([("data", torch.ones(3, 2))]),
        )

        # Case 4: a Python int attribute is incremented in forward.
        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.count = 0

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.count += 1
                return x + self.count

        # check_persistent is False since export specializes on non-tensor constants
        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module(), inp, ["script"], check_persistent=False
        )

        # Case 5: the int attribute is read before, between and after two
        # increments, so each read must observe a different value.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.count = 0

            def forward(self, x):
                count1 = self.count
                self.count += 1
                count2 = self.count
                self.count += 1
                count3 = self.count
                return x + count1 + count2 + count3

        inp = (torch.ones(1),)
        self._check_equal_ts_ep_converter(M(), inp, ["script"], check_persistent=False)

        # Case 6: a buffer is updated with an augmented assignment (+=) and
        # returned; the input x is unused by forward.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w2 = torch.nn.Buffer(torch.ones(1))

            def forward(self, x: torch.Tensor):
                self.w2 += 1
                return self.w2

        inp = (torch.ones(1),)
        self._check_equal_ts_ep_converter(M, inp, ["script"], check_persistent=True)

    def test_raise_exception(self):
        # Scripted modules whose forward raises depending on an int argument.
        # Each module is checked twice: once with an input that hits the raise
        # and once with an input that avoids it.

        # Case 1: the raise guards the only return path.
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
                if y > 0:
                    raise RuntimeError("test")
                return x + y

        # match non-strict export behavior that errors when the given input leads to
        # RaiseException.
        # y == 1 takes the raising branch, so the whole check is expected to fail
        # with torch.jit.Error mentioning builtins.RuntimeError.
        with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
            inp = (torch.randn(3, 2), 1)
            self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # Matching non-strict export behavior that only executes 1 if-branch according
        # to the given input.
        # y == 0 skips the raise, so the check is expected to pass.
        inp = (torch.randn(3, 2), 0)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # Case 2: the raise sits in one arm of an if/else whose other arm assigns
        # a value (z) that is used after the branch.
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
                z = x
                if y > 0:
                    raise RuntimeError("test")
                    # z = x
                else:
                    z = x + y
                return x + y + z

        # match non-strict export behavior that errors when the given input leads to
        # RaiseException.
        with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
            inp = (torch.randn(3, 2), 1)
            self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # Matching non-strict export behavior that only executes 1 if-branch according
        # to the given input.
        inp = (torch.randn(3, 2), 0)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_context_manager(self):
        # A user-defined context manager wrapped around tensor computation in
        # forward.

        # Plain Python class implementing __enter__/__exit__; entering increments
        # count and exiting decrements it. Neither method returns a value.
        class ContextManager:
            def __init__(self) -> None:
                self.count = 0
                return

            def __enter__(self):
                self.count += 1
                return

            def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
                self.count -= 1
                return

        # The addition happens inside the with-block; the result is returned
        # after the block exits.
        class M(torch.nn.Module):
            def forward(self, x, y):
                with ContextManager():
                    res = x + y
                return res

        inp = (torch.ones(3, 3), torch.ones(3, 3))
        self._check_equal_ts_ep_converter(M(), inp)

    def test_hidden_input_name(self):
        # Callables handed to the converter in three unusual forms: a function
        # that is already scripted, a built-in aten op, and a varargs function.
        @torch.jit.script
        def func1(x):
            return x + 1

        def func2(*args):
            v = torch.cat(args, dim=1)
            return v * v

        # func1 is already a scripted function when it reaches the helper.
        self._check_equal_ts_ep_converter(func1, (torch.randn([1, 1]),))

        # A built-in op is not scripted again, so only the trace path is used.
        relu_args = (torch.ones(5, 5),)
        self._check_equal_ts_ep_converter(torch.ops.aten.relu, relu_args, ["trace"])

        # Variable-length inputs cannot be scripted, so only the trace path is
        # used: one empty double tensor followed by random matrices that share
        # a row count but differ in width.
        num_rows = 2
        col_widths = [4, 2, 1]
        cat_inputs = [torch.tensor([], dtype=torch.double)]
        cat_inputs.extend(torch.randn(num_rows, width) for width in col_widths)
        self._check_equal_ts_ep_converter(func2, tuple(cat_inputs), ["trace"])

    def test_ts2ep_multi_outputs_on_call_ops(self):
        # Calls that each produce more than one output (values plus indices):
        # torch.max with a dim, torch.topk, torch.sort and an adaptive max-pool
        # constructed with return_indices=True.
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.pool = torch.nn.AdaptiveMaxPool2d((2, 2), return_indices=True)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return (
                    torch.max(x, dim=0),
                    torch.topk(x, 3),
                    torch.sort(x, dim=0),
                    self.pool(y),
                )

        # x is a 2-D matrix for max/topk/sort; y is a 4-D input for the pool.
        inp = (torch.randn([4, 4]), torch.randn([1, 1, 10, 10]))
        self._check_equal_ts_ep_converter(M(), inp)

    def test_aten_append_t(self):
        # Appending tensors to a Python list inside forward. The list is
        # concatenated twice, once after two appends and once after a third, so
        # out1 and out2 see different list contents.
        class M(torch.nn.Module):
            def forward(self, x: List[torch.Tensor]):
                out = []
                out.append(x[0] + x[1])
                out.append(x[0] - x[1])
                out1 = torch.cat(out)
                out.append(x[0] * x[1])
                out2 = torch.cat(out)
                return out, out1, out2

        # The single positional argument is itself a list of two tensors.
        inp = ([torch.ones(2, 3), torch.ones(2, 3)],)
        # Trace already unrolls the list.
        self._check_equal_ts_ep_converter(M(), inp, ["script"])

    def test_convert_script_object(self):
        # A module holding a torchbind tensor queue (created by
        # _empty_tensor_queue) that forward drives in two ways: through method
        # calls on the object and through the _TorchScriptTesting queue ops.
        class M1(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tq = _empty_tensor_queue()

            def forward(self, x: torch.Tensor):
                # Two pushes (method call, then op) followed by two pops (op,
                # then method call).
                self.tq.push(x)
                torch.ops._TorchScriptTesting.queue_push(self.tq, x.cos())
                return torch.ops._TorchScriptTesting.queue_pop(self.tq), self.tq.pop()

        inp = (torch.randn(2, 3),)
        self._check_equal_ts_ep_converter(M1(), inp, ["script"])

    def test_ts2ep_with_loop(self):
        # Loop handling. Only func1 is exercised at present; func2 and func3 are
        # referenced solely by the disabled TODO calls at the end.

        # Nested for-loops with constant ranges: an outer loop containing two
        # sequential inner loops. a, b and c are loop-carried values that are
        # updated but not returned; x_list is grown by append in both inner
        # loops and is returned together with x.
        def func1(x, x_list: List[torch.Tensor]):
            a, b, c = x, x, x
            for i in range(1, 5, 2):
                for k in range(5):
                    a = a + a + k
                    b = b + b - k
                    x_list.append(x_list[k] + x_list[k + 1])
                for k in range(5):
                    b = b + b - k
                    c = c + c * k
                    x_list.append(x_list[k] + x_list[k + 1] - x_list[k + 2])
            return x, x_list

        # For-loop whose trip count comes from the input's first dimension.
        def func2(x):
            for i in range(x.size(0)):
                x = x * x * i
            return x

        # While-loop with a data-dependent condition and an in-place update.
        def func3(x):
            while x.sum() < 10:
                x += x.sin()
            return x

        inp = (
            torch.tensor(1),
            [torch.ones([2, 2]), torch.ones([2, 2]) * 2],
        )
        # Trace unrolls the loop.
        self._check_equal_ts_ep_converter(func1, inp, ["script"])

        # NOTE(review): `inp` above matches func1's (tensor, list) signature;
        # func2 and func3 each take a single tensor, so the disabled calls below
        # will need their own inputs when they are enabled.

        # TODO: (2/N)
        # Trace unrolls the loop.
        # self._check_equal_ts_ep_converter(func2, inp, ["script"])

        # TODO: (3/N)
        # Trace unrolls the loop.
        # self._check_equal_ts_ep_converter(func3, inp, ["script"])

    @unittest.skipIf(IS_WINDOWS, "Windows does not support qnnpack")
    def test_ts2ep_convert_quantized_model(self):
        # Quantize a small two-conv model in eager mode under the qnnpack engine,
        # script it, convert the scripted model with TS2EPConverter, and compare
        # the converted program's outputs with the scripted model's outputs.
        class Standalone(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = torch.nn.Conv2d(1, 1, 1)
                self.conv2 = torch.nn.Conv2d(1, 1, 1)
                self.relu = torch.nn.ReLU()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.conv2(x)
                x = self.relu(x)
                x = self.dequant(x)
                return x

            def fuse_model(self):
                torch.ao.quantization.fuse_modules(
                    self, [["conv2", "relu"]], inplace=True
                )

        with override_quantized_engine("qnnpack"):
            # Eager-mode flow: qconfig -> fuse -> prepare -> one forward pass on
            # random data -> convert.
            model = Standalone()
            model.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
            model.fuse_model()
            torch.ao.quantization.prepare(model, inplace=True)
            model(torch.randn(4, 1, 4, 4))
            torch.ao.quantization.convert(model, inplace=True)

            # The shared checking helper is bypassed here because the
            # quantization pass modifies the state_dict; outputs are compared
            # directly instead.
            example_args = (torch.randn(4, 1, 4, 4),)
            reference_ts = torch.jit.script(model)
            scripted = torch.jit.script(model)
            exported = TS2EPConverter(scripted, example_args).convert()

            expected_leaves = pytree.tree_flatten(reference_ts(*example_args))[0]
            actual_leaves = pytree.tree_flatten(exported.module()(*example_args))[0]
            self._check_tensor_list_equal(expected_leaves, actual_leaves)

    def test_ts2ep_convert_quantized_model_with_opcontext(self):
        # A module that stores the object returned by
        # torch.ops.prepacked.linear_clamp_prepack as an attribute and passes it
        # to torch.ops.prepacked.linear_clamp_run in forward.
        class M(torch.nn.Module):
            def __init__(self, linear_op):
                super().__init__()
                self.linear_op = linear_op

            def forward(self, x):
                x = torch.ops.prepacked.linear_clamp_run(x, self.linear_op)
                return x

        # Prepack a random 10x10 weight and a length-10 bias.
        linear_op = torch.ops.prepacked.linear_clamp_prepack(
            torch.randn(10, 10), torch.randn(10)
        )
        m = M(linear_op)
        inp = (torch.randn(1, 10),)
        self._check_equal_ts_ep_converter(m, inp, ["script"])


# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()
