# Owner(s): ["module: fx"]

import operator

import torch
import torch.fx
from torch.fx.experimental import const_fold
from torch.fx.passes.shape_prop import _extract_tensor_metadata, ShapeProp
from torch.testing._internal.common_utils import TestCase


class TestConstFold(TestCase):
    def _get_attr(self, node):
        mod = node.graph.owning_module
        target = str(node.target)
        target_atoms = target.split(".")
        curr_obj = mod
        for i, atom in enumerate(target_atoms):
            if not hasattr(curr_obj, atom):
                raise RuntimeError(
                    f"Node referenced nonexistent target '{'.'.join(target_atoms[:i])}'; "
                    f" original whole target: '{target}'"
                )
            curr_obj = getattr(curr_obj, atom)
        return curr_obj

    def _verify_const_fold_mod(self, mod_folded: const_fold.FoldedGraphModule):
        self.assertTrue(mod_folded.const_subgraph_module is not None)

        # Check that we don't have the const or non-const fold graphs in the gm, and
        # that we do have the const folded get_attr.
        found_folded_attrs = False
        for n in mod_folded.graph.nodes:
            if n.op == "get_attr" and n.target.startswith("_FX_CONST_FOLDED_ATTRS"):
                found_folded_attrs = True
            elif n.op == "call_module":
                self.assertTrue(n.target not in {"submod_0", "submod_1"})
        self.assertTrue(found_folded_attrs)

    def test_const_fold_basic_one_attr_no_name_collision(self):
        r"""
        Perform constant folding conversion, from original mod to split constant folding
        module with two split subgraphs, where there's a single attr to fold and
        a single output attr result to replace.

           attr1                 attr1
            | |                   | |
        x   add                   add
         \ /                       |
         sub   y                 output     (becomes attr add_1)
            \ /         ==> -------+------- (const/base subgraph split)
            mul  attr2       x   /          (input from previous subgraph
              \ /             \ /            is attr)
              add             sub   y
               |                 \ /
             output              mul  attr2
                                   \ /
                                   add
                                    |
                                  output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]))
                self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]))

            def forward(self, x, y):
                a = self.attr_1 + self.attr_1
                x = x - a
                return x * y + self.attr_2

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9])
        base_result = mod(in_x, in_y)
        fold_result = mod_folded(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_basic_one_attr_name_collision(self):
        r"""
        Perform constant folding conversion, from original mod to split constant folding
        module with two split subgraphs, where there's a single attr to fold and
        a single output attr result to replace. Name the attrs such that they will
        collide by name with folded attrs.

           add_1                 add_1
            | |                   | |
        x   add                   add
         \ /                       |
         sub   y                 output     (becomes attr add_1)
            \ /         ==> -------+------- (const/base subgraph split)
            mul  add_2       x   /          (input from previous subgraph
              \ /             \ /            is attr)
              add             sub   y
               |                 \ /
             output              mul  add_2
                                   \ /
                                   add
                                    |
                                  output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # Note: Named as such to result in name collision.
                self.add_1__CF = torch.nn.Parameter(torch.tensor([[1.0]]))
                self.add_2__CF = torch.nn.Parameter(torch.tensor([[17.1]]))

            def forward(self, x, y):
                a = self.add_1__CF + self.add_1__CF
                x = x - a
                return x * y + self.add_2__CF

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x, in_y = torch.tensor([[5.0]]), torch.tensor([4.0])
        base_result = mod(in_x, in_y)
        fold_result = mod_folded(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_basic_placeholder_reordered(self):
        """
        Test code path where placeholder comes after normal op node in FX
        """

        class ConstFoldTestModule(torch.nn.Module):
            def forward(self, x, y):
                return x * 2 + y

        mod = ConstFoldTestModule()
        mod = torch.fx.symbolic_trace(mod)
        yy = None
        for n in mod.graph.nodes:
            if n.op == "placeholder" and n.target == "y":
                yy = n
            elif yy is not None and n.op == "call_function":
                yy.prepend(n)
                break

        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)

        self.assertTrue(mod_folded.const_subgraph_module is None)
        # Now run both folded and non-folded to check results equal.
        in_x = torch.tensor([[-0.45]])
        in_y = torch.tensor([[0.45]])
        base_result = mod(in_x, in_y)
        fold_result = mod_folded(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_noop(self):
        r"""
        Check that a graph with no constant folding is handled correctly.

        x  attr1
         \ /
         sub
          |
        output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]]))

            def forward(self, x):
                return x - self.attr1

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)

        # Check that the folded graph module is None, since there was no folding to do.
        self.assertTrue(mod_folded.const_subgraph_module is None)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.tensor([[-0.45]])
        base_result = mod(in_x)
        fold_result = mod_folded(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_basic_two_attr_three_input(self):
        r"""
        Perform constant folding conversion, from original mod to split constant
        folding module with two split subgraphs, where there are two attrs to
        fold into a single output, and there are three placeholder inputs.

        attr1   attr2         attr1   attr2
            \   /                 \   /
         x   add                   add
          \ /                       |
          sub     y               output     (becomes attr add_1)
             \   /     ==>   -------+------- (const/base subgraph split)
              mul  z           x   /         (input from previous subgraph
                \ /             \ /           is attr)
                div              sub  y
                 |                 \ /
               output              mul  z
                                     \ /
                                     div
                                      |
                                    output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr1 = torch.nn.Parameter(torch.tensor([[-0.9]]))
                self.attr1 = torch.nn.Parameter(torch.tensor([[1.32]]))

            def forward(self, x, y, z):
                a = self.attr1 + self.attr1
                sub = x - a
                mul = sub * y
                return mul / z

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x, in_y, in_z = (
            torch.tensor([[-0.45]]),
            torch.tensor([0.9]),
            torch.tensor([1.1]),
        )
        base_result = mod(in_x, in_y, in_z)
        fold_result = mod_folded(in_x, in_y, in_z)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_basic_two_attr(self):
        r"""
        Perform constant folding conversion, from original mod to split constant
        folding module with two split subgraphs, where there are two attrs to
        fold into a single output.

        attr1  attr2                attr1  attr2
            \ /                         \ /
        x   add                         add       (becomes attr add_1)
         \ /            ==>       -------+------- (const/base subgraph split)
         sub                         x   |        (input from previous subgraph is attr)
          |                           \ /
        output                        sub
                                       |
                                     output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr1 = torch.nn.Parameter(torch.randn(2, 3))
                self.attr2 = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                y = self.attr1 + self.attr2
                return x + y

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = mod_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_multi_const_folded_attrs(self):
        r"""
        Perform constant folding conversion, from original mod to split constant
        folding module with two split subgraphs, where there are two attrs to
        fold into two new attrs.

           attr1        attr2          attr1     attr2
           /    \         |           /     \      |
        permute  |       sum       permute   |    sum
            \   /        /                \ /      |
         x   add    y   /                 add      |
          \ /        \ /                   |       |
          sub        add                 output  output     (become attrs add_1 and mul_1)
             \       /        ==>   --------+-------+------ (const/base subgraph split)
              \     /                   x   |   y   |       (inputs from previous subgraph
                add                      \ /     \ /         are attrs)
                 |                       sub     add
               linear                       \   /
                 |                           add
               sigmoid                        |
                 |                          linear
               output                         |
                                            sigmoid
                                              |
                                            output
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr1 = torch.nn.Parameter(torch.randn(4, 4))
                self.attr2 = torch.nn.Parameter(torch.randn(4, 4))
                self.lin = torch.nn.Linear(4, 4)

            def forward(self, x, y):
                a = self.attr1 + self.attr1.permute(1, 0)
                x = x - a
                amax = torch.sum(self.attr2, dim=1)
                y = y + amax
                return torch.sigmoid(self.lin(x + y))

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x, in_y = torch.randn(4, 4), torch.randn(4)
        fold_result = mod_folded(in_x, in_y)
        base_result = mod(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_submod_hierarchy(self):
        r"""
        Perform constant folding conversion, from original mod to split constant folding
        module where one of the folded attrs comes from a submod deeper in the hierarchy
        of the base module.
        """

        class TracedThroughModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.internal_attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self):
                return self.internal_attr

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.my_mod = TracedThroughModule()
                self.attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                return self.attr + self.my_mod() + x

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = mod_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_retain_node_meta(self):
        r"""
        Perform constant folding conversion, and validate that node meta is retained.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.attr + self.attr
                return x - a

        mod = ConstFoldTestModule()
        gm = torch.fx.symbolic_trace(mod)

        # Add a count for each node to check after we const fold.
        for idx, node in enumerate(gm.graph.nodes):
            if node.op != "output":
                node.meta["meta_idx"] = idx

        # Pre-folding:
        # idx 0: placeholder
        # idx 1: get_attr (will no longer be used, hence removed)
        # idx 2: add (will be folded into a get_attr)
        # idx 3: sub

        gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
        self._verify_const_fold_mod(gm_folded)

        # Post-folding:
        # idx 0: placeholder
        # idx 2: get_attr (replaced original add; original get_attr was removed)
        # idx 3: sub

        # Check the expected indices are still here.
        for node in gm_folded.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(node.meta["meta_idx"], 0)
            elif node.op == "get_attr":
                self.assertEqual(node.meta["meta_idx"], 2)
            elif node.op == "call_function" and node.target == operator.sub:
                self.assertEqual(node.meta["meta_idx"], 3)
            else:
                self.assertEqual(node.op, "output")

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_has_inlined_call_module_node(self):
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.nn.Parameter(torch.randn(2, 3))
                self.mod = torch.nn.Identity()
                self.mod.relu = torch.nn.ReLU()

            def forward(self, x):
                a = self.attr + self.attr
                return self.mod.relu(x - a)

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_module_attr(self):
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Parameter(torch.randn(2, 3))
                self.mod = torch.nn.Identity()
                self.mod.attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.const + self.mod.attr
                x = x + a
                return x + self.mod.attr

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_const_fold_unused_placeholder(self):
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x, y, z):
                a = self.const + self.const
                return y + a

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x, in_x, in_x)
        base_result = mod(in_x, in_x, in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_dict_output(self):
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.const + self.const
                return {"result": x + a}

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result["result"], base_result["result"]))

    def test_two_outputs(self):
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.const + self.const
                return x, x + a

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result[0], base_result[0]))
        self.assertTrue(torch.equal(fold_result[1], base_result[1]))

    def test_three_outputs(self):
        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.const = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.const + self.const
                return x, x + a, x + a

        mod = ConstFoldTestModule()
        gm_folded = const_fold.split_const_subgraphs(mod)

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result[0], base_result[0]))
        self.assertTrue(torch.equal(fold_result[1], base_result[1]))
        self.assertTrue(torch.equal(fold_result[2], base_result[2]))

    def test_check_inline_non_const(self):
        r"""
        Perform constant folding conversion and check that the non-const module is inlined
        correctly.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.attr + self.attr
                return (x - a * x) / 2

        mod = ConstFoldTestModule()
        gm = torch.fx.symbolic_trace(mod)

        gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
        self._verify_const_fold_mod(gm_folded)

        # Check there are no call modules, because they've been inlined or extracted for
        # const folding.
        for node in gm_folded.graph.nodes:
            self.assertNotEqual(node.op, "call_module")

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_check_inline_non_const_mult_return(self):
        r"""
        Perform constant folding conversion and check that the non-const module is inlined
        correctly.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.nn.Parameter(torch.randn(2, 3))

            def forward(self, x):
                a = self.attr + self.attr
                return x - a, x / 2

        mod = ConstFoldTestModule()
        gm = torch.fx.symbolic_trace(mod)

        gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(gm)
        self._verify_const_fold_mod(gm_folded)

        # Check there are no call modules, because they've been inlined or extracted for
        # const folding.
        for node in gm_folded.graph.nodes:
            self.assertNotEqual(node.op, "call_module")

        # Now run both folded and non-folded to check results equal.
        in_x = torch.randn(2, 3)
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result[0], base_result[0]))
        self.assertTrue(torch.equal(fold_result[1], base_result[1]))

    def test_check_skip_folding_quant_dequant_pattern(self):
        r"""
        Set up skip_folding_quant_dequant function to skip quant/dequant pattern.
        This example shows how to use skip_folding_node_fn.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(4, 4))
                self.bias = torch.nn.Parameter(torch.randn(4))
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                quant_weight = torch.quantize_per_tensor(
                    self.weight, 0.5, 3, torch.quint8
                )
                dequant_weight = torch.dequantize(quant_weight)
                output = torch.nn.functional.linear(x, dequant_weight, self.bias)
                return self.relu(output)

        mod = ConstFoldTestModule()
        in_x = torch.randn(2, 4)
        gm = torch.fx.symbolic_trace(mod)

        def skip_folding_quant_dequant(node: torch.fx.Node):
            if node.target != torch.quantize_per_tensor:
                return False
            # If quantize_per_node -> dequantize, then skip folding.
            for user in node.users:
                if user.target == torch.dequantize:
                    return True
            return False

        gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
            gm, skip_folding_node_fn=skip_folding_quant_dequant
        )

        # Check that the folded graph module is None, since there was no folding to do.
        self.assertTrue(gm_folded.const_subgraph_module is None)

        # Now run both folded and non-folded to check results equal.
        fold_result = gm_folded(in_x)
        base_result = mod(in_x)
        self.assertTrue(torch.equal(fold_result, base_result))

    def test_fold_module(self):
        r"""
        Perform constant folding with a call_module node.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin_input = torch.nn.Parameter(torch.randn(4, 4))
                self.lin = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.lin(self.lin_input) + x

        mod = ConstFoldTestModule()
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(mod)
        self._verify_const_fold_mod(mod_folded)

        # Now run both folded and non-folded to check results equal.
        inp = torch.randn(4, 4)
        self.assertTrue(torch.equal(mod_folded(inp), mod(inp)))

    def test_const_fold_tensor_meta(self):
        self._test_const_fold_tensor_meta(True)
        self._test_const_fold_tensor_meta(False)

    def _test_const_fold_tensor_meta(self, requires_grad):
        """
        Verify tensor_meta is handled correctly.
        """

        class ConstFoldTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([[-0.9]]), requires_grad)
                self.attr_2 = torch.nn.Parameter(torch.tensor([[17.1]]), requires_grad)

            def forward(self, x, y):
                a = self.attr_1 + self.attr_1
                x = x - a
                return x * y + self.attr_2

        mod = ConstFoldTestModule()
        gm = torch.fx.symbolic_trace(mod)
        in_x, in_y = torch.tensor([[-0.45]]), torch.tensor([0.9])
        ShapeProp(gm).propagate(in_x, in_y)
        mod_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
            gm, device_for_folded_attrs="cpu"
        )
        self._verify_const_fold_mod(mod_folded)

        mod_folded.run_folding()

        for n in mod_folded.graph.nodes:
            if n.op == "get_attr":
                attr = self._get_attr(n)
                self.assertEqual(_extract_tensor_metadata(attr), n.meta["tensor_meta"])

        # Now run both folded and non-folded to check results equal.
        base_result = mod(in_x, in_y)
        fold_result = mod_folded(in_x, in_y)
        self.assertTrue(torch.equal(fold_result, base_result))
