# Owner(s): ["oncall: export"]
import unittest
from typing import Any, Dict, Optional, OrderedDict, Tuple

import torch
from torch._export.passes.lift_constants_pass import (
    ConstantAttrMap,
    lift_constants_pass,
)
from torch.export._unlift import _unlift_exported_program_lifted_states
from torch.export.exported_program import (
    ExportGraphSignature,
    InputKind,
    InputSpec,
    OutputKind,
    OutputSpec,
    TensorArgument,
)
from torch.export.graph_signature import CustomObjArgument
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    TestCase,
)


class GraphBuilder:
    def __init__(self) -> None:
        self.graph = torch.fx.Graph()
        self.nodes = {}
        self.values = {}
        self.nn_module_stack_key: Dict[str, int] = {}
        self.latest_id = 0
        self.input_to_kind: Dict[torch.fx.Node, InputKind] = {}

    def input(self, name: str, value: torch.Tensor, kind: InputKind):
        node = self.graph.placeholder(name)
        node.meta["val"] = value
        self.nodes[name] = node
        self.values[name] = value
        self.input_to_kind[node] = kind

    def add(self, x: str, y: str, out: str, module_fqn: str = ""):
        node = self.graph.create_node(
            "call_function",
            torch.ops.aten.add.Tensor,
            (self.nodes[x], self.nodes[y]),
            name=out,
        )
        self.values[out] = self.values[x] + self.values[y]
        node.meta["val"] = self.values[out]
        node.meta["nn_module_stack"] = self.create_nn_module_stack(module_fqn)
        self.nodes[out] = node

    def call_function(self, target, args, out: str, module_fqn: str = ""):
        arg_nodes = tuple(self.nodes[arg] for arg in args)
        arg_values = tuple(self.values[arg] for arg in args)
        node = self.graph.create_node(
            "call_function",
            target,
            arg_nodes,
            name=out,
        )
        self.values[out] = target(*arg_values)
        node.meta["val"] = self.values[out]
        node.meta["nn_module_stack"] = self.create_nn_module_stack(module_fqn)
        self.nodes[out] = node

    def constant(
        self, name: str, value: Any, target: Optional[str] = None, module_fqn: str = ""
    ):
        if target is None:
            target = name
        node = self.graph.get_attr(target)
        node.meta["val"] = value
        node.meta["nn_module_stack"] = self.create_nn_module_stack(module_fqn)
        self.nodes[name] = node
        self.values[name] = value

    def output(self, out: str):
        self.graph.output(self.nodes[out])

    def create_nn_module_stack(
        self, module_fqn: str
    ) -> OrderedDict[int, Tuple[str, type]]:
        cur_name = ""
        nn_module_stack = OrderedDict()
        for atom in module_fqn.split("."):
            if cur_name == "":
                cur_name = atom
            else:
                cur_name = cur_name + "." + atom

            if cur_name not in self.nn_module_stack_key:
                id_counter = self.latest_id
                self.latest_id += 1
                self.nn_module_stack_key[cur_name] = id_counter
            else:
                id_counter = self.nn_module_stack_key[cur_name]

            nn_module_stack[id_counter] = (cur_name, torch.nn.Module)
        return nn_module_stack

    def create_input_specs(self):
        input_specs = []
        for node in self.graph.nodes:
            if node.op == "placeholder":
                input_specs.append(
                    InputSpec(
                        kind=self.input_to_kind[node],
                        arg=TensorArgument(name=node.name),
                        target=None,
                        persistent=(
                            True
                            if self.input_to_kind[node] == InputKind.BUFFER
                            else None
                        ),
                    )
                )
        return input_specs

    # NOTE: does not handle non-user-outputs atm
    def gen_graph_signature(self) -> ExportGraphSignature:
        output = [n for n in self.graph.nodes if n.op == "output"]
        assert len(output) == 1
        output = output[0]
        assert len(output.args) == 1, "multiple outputs NYI"

        return ExportGraphSignature(
            input_specs=self.create_input_specs(),
            output_specs=[
                OutputSpec(
                    kind=OutputKind.USER_OUTPUT,
                    arg=TensorArgument(name=n.name),
                    target=None,
                )
                for n in output.args
            ],
        )


class TestLift(TestCase):
    """Tests for ``lift_constants_pass`` and for unlifting exported programs."""

    def setUp(self):
        # Several tests build ``_TorchScriptTesting`` custom objects/ops, which
        # come from a separately built test library loaded here per platform.
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        elif IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library(
                "//caffe2/test/cpp/jit:test_custom_class_registrations"
            )
        elif IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
            torch.ops.load_library(str(lib_file_path))
        else:
            lib_file_path = find_library_location("libtorchbind_test.so")
            torch.ops.load_library(str(lib_file_path))

    def test_lift_basic(self):
        """A root-level tensor constant and custom object are both lifted."""
        builder = GraphBuilder()

        builder.input("param", torch.rand(2, 3), InputKind.PARAMETER)
        builder.input("buffer", torch.rand(2, 3), InputKind.BUFFER)
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("y", torch.rand(2, 3), InputKind.USER_INPUT)

        builder.add("x", "y", out="foo")
        builder.add("foo", "param", out="bar")
        builder.add("bar", "buffer", out="baz")
        builder.constant("const_tensor", torch.rand(2, 3))
        builder.constant("const_obj", torch.classes._TorchScriptTesting._Foo(10, 20))
        builder.add("baz", "const_tensor", out="out")
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj", "x"),
            out="torchbind_out",
        )
        builder.add("out", "torchbind_out", out="final_out")
        builder.output("final_out")

        builder.graph.lint()
        graph = builder.graph
        const_tensor = builder.values["const_tensor"]
        const_obj = builder.values["const_obj"]

        root = {"const_tensor": const_tensor, "const_obj": const_obj}
        gm = torch.fx.GraphModule(root, graph)
        graph_signature = builder.gen_graph_signature()
        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 2)

        # The key of the constants table should match the fqn of the constant.
        # In this case, it's just the name of the constant, since the constant
        # is at the root submodule.
        # TODO(suo): we shouldn't hardcode these names in the test, this is an
        # internal detail of the pass.
        self.assertIn("lifted_tensor_0", constants)
        self.assertEqual(constants["lifted_tensor_0"], const_tensor)
        self.assertIn("lifted_custom_0", constants)
        self.assertEqual(constants["lifted_custom_0"], const_obj)

        # The constant node should be removed.
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # The constant should be lifted to a placeholder node.
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 6)

        # The lifted constant should be placed before user inputs but after params/buffers
        lifted_tensor_placeholder = placeholder_nodes[2]
        self.assertEqual(lifted_tensor_placeholder.target, "lifted_tensor_0")
        # It should have a val equivalent to the constant
        self.assertEqual(lifted_tensor_placeholder.meta["val"], const_tensor)

        lifted_obj_placeholder = placeholder_nodes[3]
        self.assertEqual(lifted_obj_placeholder.target, "lifted_custom_0")
        # It should have a val equivalent to the constant
        self.assertEqual(
            lifted_obj_placeholder.meta["val"],
            CustomObjArgument(
                name="lifted_custom_0",
                class_fqn="__torch__.torch.classes._TorchScriptTesting._Foo",
            ),
        )

        # Graph signature should have been mutated a way that reflects the placeholders.
        tensor_constant_input_spec = graph_signature.input_specs[2]
        self.assertEqual(tensor_constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(tensor_constant_input_spec.arg, TensorArgument)
        self.assertEqual(
            tensor_constant_input_spec.arg.name, lifted_tensor_placeholder.name
        )

        obj_constant_input_spec = graph_signature.input_specs[3]
        self.assertEqual(obj_constant_input_spec.kind, InputKind.CUSTOM_OBJ)
        self.assertIsInstance(obj_constant_input_spec.arg, CustomObjArgument)
        self.assertEqual(obj_constant_input_spec.arg.name, lifted_obj_placeholder.name)

    def test_lift_nested(self):
        """A constant used inside submodule ``foo`` is keyed by its nested fqn."""
        builder = GraphBuilder()
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("y", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.input("z", torch.rand(2, 3), InputKind.USER_INPUT)

        builder.add("x", "y", out="foo")
        builder.add("foo", "z", out="bar", module_fqn="foo")
        builder.constant("const_tensor", torch.rand(2, 3), module_fqn="foo")
        builder.add("bar", "const_tensor", "out")
        builder.output("out")

        graph = builder.graph
        graph.lint()

        const_tensor = builder.values["const_tensor"]
        root = {"const_tensor": const_tensor}

        graph_signature = builder.gen_graph_signature()
        gm = torch.fx.GraphModule(root, graph)

        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 1)

        # The key of the constants table should match the fqn of the constant.
        self.assertIn("foo.lifted_tensor_0", constants)
        self.assertEqual(constants["foo.lifted_tensor_0"], const_tensor)

        # The constant node should be removed.
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # The constant should be lifted to a placeholder node.
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 4)

        # With no params/buffers in this graph, the lifted constant becomes the
        # very first placeholder, ahead of all user inputs.
        lifted_constant_placeholder = placeholder_nodes[0]
        self.assertEqual(lifted_constant_placeholder.target, "lifted_tensor_0")

        # Graph signature should have been mutated a way that reflects the placeholders.
        constant_input_spec = graph_signature.input_specs[0]
        self.assertEqual(constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(constant_input_spec.arg, TensorArgument)
        self.assertEqual(constant_input_spec.arg.name, lifted_constant_placeholder.name)

    def test_duplicate_constant_access(self):
        """Repeated reads of one constant object are lifted to one placeholder."""
        const = torch.rand(2, 3)
        const_obj = torch.classes._TorchScriptTesting._Foo(10, 20)

        builder = GraphBuilder()
        builder.input("x", torch.rand(2, 3), InputKind.USER_INPUT)
        builder.constant("const_tensor", const, target="const_tensor")
        # loading the same target twice
        builder.constant("const_tensor2", const, target="const_tensor")

        # loading the same object twice with different targets
        builder.constant("const_obj", const_obj)
        builder.constant("const_obj2", const_obj)
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj", "x"),
            out="torchbind_out",
        )
        builder.call_function(
            torch.ops._TorchScriptTesting.takes_foo,
            ("const_obj2", "x"),
            out="torchbind_out2",
        )
        builder.add("x", "const_tensor", out="foo")
        builder.add("foo", "const_tensor2", out="tensor_out")
        builder.add("torchbind_out", "torchbind_out2", out="obj_out")
        builder.add("tensor_out", "obj_out", out="out")
        builder.output("out")
        graph = builder.graph
        graph.lint()

        input_specs = builder.create_input_specs()
        output_specs = [
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name=builder.nodes["out"].name),
                target=None,
            )
        ]
        graph_signature = ExportGraphSignature(input_specs, output_specs)

        root = {"const_tensor": const, "const_obj": const_obj, "const_obj2": const_obj}
        gm = torch.fx.GraphModule(root, graph)

        constants = lift_constants_pass(gm, graph_signature, {})
        gm.graph.lint()

        self.assertEqual(len(constants), 2)

        # All get_attr nodes should be removed
        getattr_nodes = [n for n in gm.graph.nodes if n.op == "get_attr"]
        self.assertEqual(len(getattr_nodes), 0)

        # There should only be two additional inputs (plus the existing user input)
        placeholder_nodes = [n for n in gm.graph.nodes if n.op == "placeholder"]
        self.assertEqual(len(placeholder_nodes), 3)

        # Graph signature should have been mutated a way that reflects the placeholders.
        self.assertEqual(len(graph_signature.input_specs), 3)
        # The tensor constant is lifted first (it is read first in the graph) ...
        constant_input_spec = graph_signature.input_specs[0]
        self.assertEqual(constant_input_spec.kind, InputKind.CONSTANT_TENSOR)
        self.assertIsInstance(constant_input_spec.arg, TensorArgument)
        # ... then the single shared custom object, with the user input last.
        obj_input_spec = graph_signature.input_specs[1]
        self.assertEqual(obj_input_spec.kind, InputKind.CUSTOM_OBJ)
        self.assertIsInstance(obj_input_spec.arg, CustomObjArgument)
        self.assertEqual(graph_signature.input_specs[2].kind, InputKind.USER_INPUT)

    def test_unlift_nonpersistent_buffer(self):
        """Unlifting keeps a non-persistent buffer out of the state dict."""

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_buffer(
                    "non_persistent_buf", torch.zeros(1), persistent=False
                )

            def forward(self, x):
                self.non_persistent_buf.add_(1)
                return x.sum() + self.non_persistent_buf.sum()

        foo = Foo()
        exported = torch.export.export(foo, (torch.ones(5, 5),), strict=False)
        stateful_gm = _unlift_exported_program_lifted_states(exported)

        # Check the unlifted stateful_gm contains the original non-persistent buffer
        self.assertTrue(hasattr(stateful_gm, "non_persistent_buf"))
        non_persistent_buf = stateful_gm.get_buffer("non_persistent_buf")
        self.assertEqual(non_persistent_buf, foo.get_buffer("non_persistent_buf"))
        self.assertIn("non_persistent_buf", stateful_gm._non_persistent_buffers_set)
        self.assertNotIn("non_persistent_buf", stateful_gm.state_dict())


class ConstantAttrMapTest(TestCase):
    """Tests for the mapping API of ``ConstantAttrMap``."""

    def setUp(self):
        # ``test_dict_api`` uses a ``_TorchScriptTesting._Foo`` custom object,
        # which comes from a separately built test library loaded here.
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        elif IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library(
                "//caffe2/test/cpp/jit:test_custom_class_registrations"
            )
        elif IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
            torch.ops.load_library(str(lib_file_path))
        else:
            lib_file_path = find_library_location("libtorchbind_test.so")
            torch.ops.load_library(str(lib_file_path))

    def test_dict_api(self):
        """Exercise add / len / iteration / lookup / membership / deletion."""
        constant_attr_map = ConstantAttrMap()
        const_obj = torch.classes._TorchScriptTesting._Foo(10, 20)
        const_tensor = torch.ones(2, 3)
        constant_attr_map.add(const_obj, "foo.bar")
        constant_attr_map.add(const_tensor, "foo.bar.baz")
        self.assertEqual(len(constant_attr_map), 2)
        # Iteration and keys() hand back the original objects in insertion order.
        self.assertEqual(list(constant_attr_map), [const_obj, const_tensor])
        self.assertEqual(list(constant_attr_map.keys()), [const_obj, const_tensor])
        self.assertEqual(
            list(constant_attr_map.values()), [["foo.bar"], ["foo.bar.baz"]]
        )
        self.assertEqual(constant_attr_map[const_obj], ["foo.bar"])
        self.assertEqual(constant_attr_map[const_tensor], ["foo.bar.baz"])
        self.assertIn(const_obj, constant_attr_map)
        self.assertIn(const_tensor, constant_attr_map)
        # A plain int key is rejected.
        with self.assertRaises(TypeError):
            constant_attr_map.add(1, "foo.bar")

        del constant_attr_map[const_obj]
        self.assertEqual(len(constant_attr_map), 1)
        # Deletion removes exactly the requested key and keeps the other one.
        self.assertNotIn(const_obj, constant_attr_map)
        self.assertEqual(list(constant_attr_map), [const_tensor])
        self.assertEqual(constant_attr_map[const_tensor], ["foo.bar.baz"])


if __name__ == "__main__":
    # Entry point when this file is executed directly: delegate to the
    # ``run_tests`` helper imported from ``torch.testing._internal.common_utils``.
    run_tests()
