# Owner(s): ["module: fx"]

import os
import sys
import unittest

import torch


# Append the parent of this file's directory (named as the PyTorch test dir)
# to sys.path — presumably so shared test helpers living there can be imported.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch._dynamo.eval_frame import is_dynamo_supported
from torch.fx.passes.tools_common import legalize_graph
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    get_source_partitions,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
)
from torch.testing._internal.jit_utils import JitTestCase


class TestSourceMatcher(JitTestCase):
    """Tests for ``get_source_partitions`` and ``check_subgraphs_connected``.

    Models are traced either with ``torch._dynamo.export`` (partitions are
    requested by source module type or function object), or with
    ``torch.export.export`` after clearing ``source_fn_stack`` so that the
    partitioner has to rely on the ``torch_fn`` metadata (partitions are
    requested by string names such as ``"linear"``).
    """

    @staticmethod
    def _dynamo_export(
        model: torch.nn.Module, inputs: tuple
    ) -> torch.fx.GraphModule:
        """Trace ``model`` to an ATen-level graph via ``torch._dynamo.export``.

        Unused nodes are removed with ``eliminate_dead_code`` before the graph
        is handed to the partitioner.
        """
        gm, _ = torch._dynamo.export(model, aten_graph=True)(*inputs)
        gm.graph.eliminate_dead_code()
        return gm

    @staticmethod
    def _export_for_torch_fn(
        model: torch.nn.Module, inputs: tuple, *, strict: bool
    ) -> torch.fx.GraphModule:
        """Export ``model`` with ``torch.export`` and strip ``source_fn_stack``.

        After dead-code elimination, ``source_fn_stack`` is set to ``None`` on
        every node so that ``get_source_partitions`` uses ``torch_fn`` only.
        """
        gm = torch.export.export(model, inputs, strict=strict).module()
        gm.graph.eliminate_dead_code()
        # TODO: remove this after we fix "torch_fn". T199561090
        for node in gm.graph.nodes:
            node.meta["source_fn_stack"] = None
        return gm

    def _assert_connectivity(self, expectations) -> None:
        """Assert ``check_subgraphs_connected(src, dst)`` for each triple.

        Args:
            expectations: iterable of ``(src, dst, expected)`` where ``src`` and
                ``dst`` are source partitions and ``expected`` is the bool the
                check must return. A failure message names the failing entry.
        """
        for idx, (src, dst, expected) in enumerate(expectations):
            connected = check_subgraphs_connected(src, dst)
            msg = f"expectation #{idx}: check_subgraphs_connected should be {expected}"
            if expected:
                self.assertTrue(connected, msg)
            else:
                self.assertFalse(connected, msg)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_linear_relu_linear(self):
        """linear1 is called twice, so three Linear partitions are expected."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(3, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear1(x)
                x = self.relu(x)
                x = self.linear2(x)
                return x

        inputs = (torch.randn(3, 3),)
        gm = self._dynamo_export(M(), inputs)

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.Linear, torch.nn.ReLU]
        )

        self.assertEqual(len(module_partitions), 2)
        self.assertEqual(len(module_partitions[torch.nn.Linear]), 3)
        self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)

        linears = module_partitions[torch.nn.Linear]
        relu = module_partitions[torch.nn.ReLU][0]
        self._assert_connectivity(
            [
                (linears[0], relu, False),
                (linears[1], relu, True),
                (linears[2], relu, False),
            ]
        )

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_conv_relu_maxpool(self):
        """Three Conv2d partitions, one ReLU and one MaxPool2d partition."""

        class M(torch.nn.Module):
            def __init__(self, constant_tensor: torch.Tensor) -> None:
                super().__init__()
                self.constant_tensor = constant_tensor
                self.conv1 = torch.nn.Conv2d(
                    in_channels=3, out_channels=16, kernel_size=3, padding=1
                )
                self.conv2 = torch.nn.Conv2d(
                    in_channels=16, out_channels=16, kernel_size=3, padding=1
                )
                self.conv3 = torch.nn.Conv2d(
                    in_channels=16, out_channels=16, kernel_size=3, padding=1
                )
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = self.conv1(x)
                b = self.conv2(a)
                c = a + self.constant_tensor
                z = self.conv3(b + c)
                return self.maxpool(self.relu(z))

        inputs = (torch.randn(1, 3, 256, 256),)
        gm = self._dynamo_export(M(torch.ones(1, 16, 256, 256)), inputs)

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.Conv2d, torch.nn.ReLU, torch.nn.MaxPool2d]
        )

        self.assertEqual(len(module_partitions), 3)
        self.assertEqual(len(module_partitions[torch.nn.Conv2d]), 3)
        self.assertEqual(len(module_partitions[torch.nn.ReLU]), 1)
        self.assertEqual(len(module_partitions[torch.nn.MaxPool2d]), 1)

        convs = module_partitions[torch.nn.Conv2d]
        relu = module_partitions[torch.nn.ReLU][0]
        maxpool = module_partitions[torch.nn.MaxPool2d][0]
        self._assert_connectivity(
            [
                (convs[0], relu, False),
                (convs[1], relu, False),
                (convs[2], relu, True),
                # The check is directional: ReLU feeds MaxPool, not vice versa.
                (maxpool, relu, False),
                (relu, maxpool, True),
            ]
        )

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_functional_conv_relu_conv(self):
        """Functional conv2d calls wrapped in modules are partitioned by function."""

        class FunctionalConv2d(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight, bias):
                return torch.nn.functional.conv2d(
                    x,
                    weight,
                    bias,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.groups,
                )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = FunctionalConv2d()
                self.conv2 = FunctionalConv2d()

            def forward(self, x, weight, bias):
                x = self.conv1(x, weight, bias)
                x = torch.nn.functional.relu(x)
                x = self.conv2(x, weight, bias)
                return x

        inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3))
        gm = self._dynamo_export(M(), inputs)

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.functional.conv2d]
        )

        self.assertEqual(len(module_partitions), 1)
        self.assertEqual(len(module_partitions[torch.nn.functional.conv2d]), 2)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_module_partitioner_functional_linear_relu_linear(self):
        """Four functional linear calls and two relu calls, one partition each."""

        class M(torch.nn.Module):
            def forward(self, x, weight, bias):
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.relu(x)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.relu(x)
                return x

        inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5))
        gm = self._dynamo_export(M(), inputs)

        module_partitions = get_source_partitions(
            gm.graph, [torch.nn.functional.linear, torch.nn.functional.relu]
        )

        self.assertEqual(len(module_partitions), 2)
        self.assertEqual(len(module_partitions[torch.nn.functional.linear]), 4)
        self.assertEqual(len(module_partitions[torch.nn.functional.relu]), 2)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    def test_legalize_slice(self):
        """A graph with a data-dependent slice still runs after legalize_graph.

        Both the exported module and its legalized version are interpreted
        under the fake mode of the exported placeholders; the test passes if
        neither run raises.
        """

        class M(torch.nn.Module):
            def forward(self, x, y):
                b = x.item()
                torch._check_is_size(b)
                torch._check(b + 1 < y.size(0))
                return y[: b + 1]

        ep = torch.export.export(M(), (torch.tensor(4), torch.randn(10)))
        fake_inputs = [
            node.meta["val"] for node in ep.graph.nodes if node.op == "placeholder"
        ]
        gm = ep.module()
        with fake_inputs[0].fake_mode:
            torch.fx.Interpreter(gm).run(*fake_inputs)
        legalized_gm = legalize_graph(gm)
        with fake_inputs[0].fake_mode:
            torch.fx.Interpreter(legalized_gm).run(*fake_inputs)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    @parametrize("strict", (True, False))
    def test_module_partitioner_linear_relu_linear_torch_fn_export(self, strict: bool):
        """torch_fn-only variant of test_module_partitioner_linear_relu_linear."""

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(3, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear1(x)
                x = self.relu(x)
                x = self.linear2(x)
                return x

        inputs = (torch.randn(3, 3),)
        gm = self._export_for_torch_fn(M(), inputs, strict=strict)

        module_partitions = get_source_partitions(gm.graph, ["linear", "relu"])

        self.assertEqual(len(module_partitions), 2)
        self.assertEqual(len(module_partitions["linear"]), 3)
        self.assertEqual(len(module_partitions["relu"]), 1)

        linears = module_partitions["linear"]
        relu = module_partitions["relu"][0]
        self._assert_connectivity(
            [
                (linears[0], relu, False),
                (linears[1], relu, True),
                (linears[2], relu, False),
            ]
        )

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    @parametrize("strict", (True, False))
    def test_module_partitioner_conv_relu_maxpool_torch_fn_export(self, strict: bool):
        """torch_fn-only variant of test_module_partitioner_conv_relu_maxpool."""

        class M(torch.nn.Module):
            def __init__(self, constant_tensor: torch.Tensor) -> None:
                super().__init__()
                self.constant_tensor = constant_tensor
                self.conv1 = torch.nn.Conv2d(
                    in_channels=3, out_channels=16, kernel_size=3, padding=1
                )
                self.conv2 = torch.nn.Conv2d(
                    in_channels=16, out_channels=16, kernel_size=3, padding=1
                )
                self.conv3 = torch.nn.Conv2d(
                    in_channels=16, out_channels=16, kernel_size=3, padding=1
                )
                self.relu = torch.nn.ReLU()
                self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                a = self.conv1(x)
                b = self.conv2(a)
                c = a + self.constant_tensor
                z = self.conv3(b + c)
                return self.maxpool(self.relu(z))

        inputs = (torch.randn(1, 3, 256, 256),)
        gm = self._export_for_torch_fn(
            M(torch.ones(1, 16, 256, 256)), inputs, strict=strict
        )

        module_partitions = get_source_partitions(
            gm.graph, ["conv2d", "relu", "max_pool2d"]
        )

        self.assertEqual(len(module_partitions), 3)
        self.assertEqual(len(module_partitions["conv2d"]), 3)
        self.assertEqual(len(module_partitions["relu"]), 1)
        self.assertEqual(len(module_partitions["max_pool2d"]), 1)

        convs = module_partitions["conv2d"]
        relu = module_partitions["relu"][0]
        maxpool = module_partitions["max_pool2d"][0]
        self._assert_connectivity(
            [
                (convs[0], relu, False),
                (convs[1], relu, False),
                (convs[2], relu, True),
                # The check is directional: relu feeds max_pool2d, not vice versa.
                (maxpool, relu, False),
                (relu, maxpool, True),
            ]
        )

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    @parametrize("strict", (True, False))
    def test_module_partitioner_functional_conv_relu_conv_torch_fn_export(
        self, strict: bool
    ):
        """torch_fn-only variant of test_module_partitioner_functional_conv_relu_conv."""

        class FunctionalConv2d(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight, bias):
                return torch.nn.functional.conv2d(
                    x,
                    weight,
                    bias,
                    self.stride,
                    self.padding,
                    self.dilation,
                    self.groups,
                )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = FunctionalConv2d()
                self.conv2 = FunctionalConv2d()

            def forward(self, x, weight, bias):
                x = self.conv1(x, weight, bias)
                x = torch.nn.functional.relu(x)
                x = self.conv2(x, weight, bias)
                return x

        inputs = (torch.randn(1, 3, 5, 5), torch.rand(3, 3, 3, 3), torch.rand(3))
        gm = self._export_for_torch_fn(M(), inputs, strict=strict)

        module_partitions = get_source_partitions(gm.graph, ["conv2d"])

        self.assertEqual(len(module_partitions), 1)
        self.assertEqual(len(module_partitions["conv2d"]), 2)

    @unittest.skipIf(not is_dynamo_supported(), "Dynamo not supported")
    @parametrize("strict", (True, False))
    def test_module_partitioner_functional_linear_relu_linear_torch_fn_export(
        self, strict: bool
    ):
        """torch_fn-only variant of test_module_partitioner_functional_linear_relu_linear."""

        class M(torch.nn.Module):
            def forward(self, x, weight, bias):
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.relu(x)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.linear(x, weight, bias)
                x = torch.nn.functional.relu(x)
                return x

        inputs = (torch.randn(1, 5), torch.rand((5, 5)), torch.zeros(5))
        gm = self._export_for_torch_fn(M(), inputs, strict=strict)

        module_partitions = get_source_partitions(gm.graph, ["linear", "relu"])

        self.assertEqual(len(module_partitions), 2)
        self.assertEqual(len(module_partitions["linear"]), 4)
        self.assertEqual(len(module_partitions["relu"]), 2)


# Expand the @parametrize("strict", ...) methods of TestSourceMatcher into
# concrete per-parameter test methods; without this call those methods would
# still expect a `strict` argument when run.
instantiate_parametrized_tests(TestSourceMatcher)
