# Owner(s): ["module: fx"]

import copy
import unittest
from typing import Optional, Set, Type

import torch
import torch.fx
from torch.testing._internal.common_utils import IS_MACOS, TestCase


class TestDCE(TestCase):
    def _custom_is_impure_node(self, node: torch.fx.Node) -> bool:
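        """Impurity check that also treats ``torch.ops.aten.add`` as impure."""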
        if node.is_impure():
            return True
        # In addition to the built-in impurity checks, treat add as impure.
        if node.target == torch.ops.aten.add:
            return True
        return False

    def _has_nodes_without_users(
        self, m: torch.fx.GraphModule, custom: bool = False
    ) -> bool:
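        """Return True if the graph contains a pure node with no users."""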
        for node in m.graph.nodes:
            if (not custom and node.is_impure()) or (
                custom and self._custom_is_impure_node(node)
            ):
                continue
            if len(node.users) == 0:
                return True
        return False

    def _get_num_placeholders(self, m: torch.fx.GraphModule) -> int:
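        """Return the number of placeholder nodes in the graph."""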
        count = 0
        for node in m.graph.nodes:
            if node.op == "placeholder":
                count += 1
        return count

    def _run_dce_and_test(
        self,
        m: torch.nn.Module,
        expect_dce_changes: bool,
        modules_to_be_leafs: Optional[Set[Type]] = None,
        custom: bool = False,
    ):
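        """
        Trace ``m``, run dead code elimination on the traced graph, and verify
        that dead nodes are removed (or kept) as expected, that placeholders
        survive DCE, and that the module's outputs are unchanged.
        """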
        class TestTracer(torch.fx.Tracer):
            def is_leaf_module(self, m, qualname):
                if modules_to_be_leafs and type(m) in modules_to_be_leafs:
                    return True
                return super().is_leaf_module(m, qualname)

        traced: torch.fx.GraphModule = torch.fx.GraphModule(m, TestTracer().trace(m))
        print(str(traced.graph))

        # Verify there are nodes without users (if expected).
        has_nodes_without_users = self._has_nodes_without_users(traced, custom=custom)
        if expect_dce_changes:
            self.assertTrue(has_nodes_without_users)
        else:
            self.assertFalse(has_nodes_without_users)

        # Get the original number of placeholders to verify it doesn't change
        # during DCE.
        orig_num_phs = self._get_num_placeholders(traced)
        if custom:
            changed = traced.graph.eliminate_dead_code(
                is_impure_node=self._custom_is_impure_node
            )
        else:
            changed = traced.graph.eliminate_dead_code()

        self.assertEqual(changed, expect_dce_changes)

        # Verify there are no nodes without users after DCE is run.
        self.assertFalse(self._has_nodes_without_users(traced, custom=custom))
        new_num_phs = self._get_num_placeholders(traced)
        self.assertEqual(orig_num_phs, new_num_phs)

        traced.recompile()
        # Make sure we run and get the same results before/after DCE.
        inputs = [torch.tensor([1.5])] * new_num_phs
        inputs_copy = copy.deepcopy(inputs)
        self.assertTrue(torch.equal(m(*inputs), traced(*inputs_copy)))

    def test_simple(self):
        """
        Tests that a single node in the graph is DCE'd correctly.
        """

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

            def forward(self, x):
                a = x + 1  # dead node: no users, so DCE should remove it
                return x + self.attr_1

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_dead_chain(self):
        """
        Tests that a chain of two nodes in the graph is DCE'd correctly.
        """

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

            def forward(self, x):
                a = x + 1
                b = a * 7  # dead chain: neither a nor b is used
                return x + self.attr_1

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_dead_getattr(self):
        """
        Tests that a get_attr node in the graph is DCE'd correctly.
        """

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))

            def forward(self, x):
                a = x + 1
                b = a * self.attr_1  # dead node whose get_attr input is also removed
                return x + 11

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_dead_placeholder(self):
        """
        Tests that a placeholder in the graph is not DCE'd, as that would change
        the function signature.
        """

        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                return x + 7

        self._run_dce_and_test(TestModule(), expect_dce_changes=False)

    def test_dead_placeholder_with_user(self):
        """
        Tests that a placeholder in the graph is not DCE'd, as that would change
        the function signature. Also verifies that a dead node that uses the
        placeholder is DCE'd.
        """

        class TestModule(torch.nn.Module):
            def forward(self, x, y):
                a = y + 2  # dead node that uses the otherwise-unused placeholder y
                return x + 7

        self._run_dce_and_test(TestModule(), expect_dce_changes=True)

    def test_keep_module_with_side_effects(self):
        """
        Test that DCE doesn't remove a module if it's specified as having side effects.
        """

        class ReLUImpure(torch.nn.ReLU):
            _is_impure = True

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = ReLUImpure()

            def forward(self, a: torch.Tensor) -> torch.Tensor:
                r = self.relu(a)  # result unused, but ReLUImpure is impure so the call is kept
                return a * 2

        self._run_dce_and_test(
            TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure}
        )

    def test_keep_torch_assert(self):
        """
        Test that DCE doesn't remove torch._assert since it has side effects.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a: torch.Tensor) -> torch.Tensor:
                torch._assert(torch.equal(a, a), "a must equal a")
                return a * 2

        # Note: Don't need to specify torch._assert as having side effects
        # because it's known to.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False)

    def test_impure_nodes_args(self):
        """
        Test that DCE doesn't remove call_function nodes with side effects.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a: torch.Tensor) -> torch.Tensor:
                torch._ops.ops.aten.add_.Tensor(a, 1)
                return a * 2

        # %add_ node should not be removed because it has side effects.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False)

    def test_impure_kwargs(self):
        """
        Test that DCE doesn't remove call_function nodes whose side effects flow through kwargs.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a: torch.Tensor) -> torch.Tensor:
                b = a + 1
                torch._ops.ops.aten.add.out(b, b, out=a, alpha=2)
                return a

        # %add_out node should not be removed because it has side effects.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False)

    def test_impure_custom(self):
        """
        Test that DCE doesn't remove nodes marked as impure by a custom function.
        """

        class TestModule(torch.nn.Module):
            def forward(self, a: torch.Tensor) -> torch.Tensor:
                b = a + 1
                c = torch._ops.ops.aten.add(b, b)
                return a

        # %add node should not be removed because the custom function marks it as impure.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)

    @unittest.skipIf(IS_MACOS, "Not working on macOS")
    def test_keep_collectives(self):
        """
        Test that DCE doesn't remove collective ops even if their results are not used.
        """

        from torch.testing._internal.distributed.fake_pg import FakeStore

        class TestModule(torch.nn.Module):
            def forward(
                self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
            ) -> torch.Tensor:
                d = torch.ops.aten.mul.Tensor(a, b)
                e = torch.ops.aten.mul.Tensor(a, c)
                future = torch.ops._c10d_functional.all_reduce.default(e, "sum", "0")
                synced_e = torch.ops._c10d_functional.wait_tensor.default(
                    future
                )  # synced_e is not used
                return d

        torch.distributed.init_process_group(
            backend="fake",
            world_size=2,
            rank=0,
            store=FakeStore(),
        )
        # collective nodes should not be removed because they have side effects.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
        torch.distributed.destroy_process_group()

    @unittest.skipIf(IS_MACOS, "Not working on macOS")
    def test_keep_collectives_no_overload(self):
        """
        Test that DCE doesn't remove collective ops (non-overload version) even if their results are not used.
        """

        from torch.testing._internal.distributed.fake_pg import FakeStore

        class TestModule(torch.nn.Module):
            def forward(
                self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
            ) -> torch.Tensor:
                d = torch.ops.aten.mul(a, b)
                e = torch.ops.aten.mul(a, c)
                future = torch.ops._c10d_functional.all_reduce(e, "sum", "0")
                synced_e = torch.ops._c10d_functional.wait_tensor(
                    future
                )  # synced_e is not used
                return d

        torch.distributed.init_process_group(
            backend="fake",
            world_size=2,
            rank=0,
            store=FakeStore(),
        )
        # collective nodes should not be removed because they have side effects.
        self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
        torch.distributed.destroy_process_group()
