# Owner(s): ["module: fx.passes"]
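
"""Tests for FX graph passes: partition proposal with CapabilityBasedPartitioner,
partition fusion with fuse_by_partitions, and pattern matching with
SubgraphMatcher."""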

from dataclasses import dataclass
import operator
import logging
import sys

import torch
from torch.fx._symbolic_trace import symbolic_trace

from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner
from torch.fx.passes.operator_support import OperatorSupport
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher

from torch.testing._internal.common_utils import run_tests, parametrize, instantiate_parametrized_tests
from torch.testing._internal.jit_utils import JitTestCase

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

class TestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.linear2 = torch.nn.Linear(4, 4)
        self.param = torch.nn.Parameter(torch.rand(4, 4))

    def forward(self, a, b, c):
        add = a + b

        linear_1 = self.linear(add)

        add_1 = add + c
        add_2 = add_1 + self.param
        add_3 = add_1 + linear_1
        add_4 = add_2 + add_3

        linear_2 = self.linear2(add_4)

        add_5 = linear_2 + add_4
        add_6 = add_5 + a
        relu = add_6.relu()

        return add_4, add_6, relu

class TestDeepModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, a, b, c):
        o = a + b
        o = o + 1.0

        # tests that passes avoid recursive DFS: the loop below builds a graph
        # deeper than Python's max recursion depth
        for _ in range(sys.getrecursionlimit() + 1):
            o = o - c

        return o


class TestPartitionFunctions:
    @staticmethod
    def forward1(a, b, c):
        add = a + b
        add_1 = add + b
        add_2 = add_1 + c
        relu_1 = add_2.relu()
        add_3 = add_1 + add_2
        add_4 = add_1 + relu_1 + add_3
        relu_2 = add_4.relu()
        add_5 = relu_2 + add_4
        add_6 = add_5 + add_4
        return add_4, add_6

    @staticmethod
    def forward2(a, b, _):
        add = a + b
        add_1 = add + b
        relu_1 = add_1.relu()  # blocked by this
        add_3 = add_1 + relu_1
        add_4 = add_1 + add_3
        return add_4, add_1

    @staticmethod
    def forward3(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return add, add_1, add_2

    @staticmethod
    def forward4(a, b, c):
        add = a + b
        add_1 = a + c
        add_2 = b + c
        return torch.where(add > 0, add_1, add_2)

    @staticmethod
    def forward5(a, b, c):
        # add should be fused with the right branch, as the left branch is not supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        add_1 = add + 2
        return relu, add_1

    @staticmethod
    def forward6(a, b, c):
        # add should have its own partition, as neither branch is supported
        add = a + 1
        # left branch
        relu = add.relu()
        # right branch
        relu_1 = add.relu()
        return relu, relu_1

    @staticmethod
    def forward7(a, b, c):
        # both branches are supported, so all adds should be fused together
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch is larger
        add_2 = add + 1
        add_3 = add_2 + 1
        return add_3, add_1

    @staticmethod
    def forward8(a, b, c):
        # both branches are in the same partition, so add should join that partition
        add = a + 1
        # left branch
        add_1 = add + 2
        # right branch
        add_2 = add + 1
        # left and right branches merge
        add_3 = add_2 + add_1

        return add_3

    @staticmethod
    def forward9(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3
        add_3 = add + 1
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward10(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add + 1
        # branch 2
        add_2 = add + 1
        # branch 3: depends on branch 2
        add_3 = add + add_2
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward11(a, b, c):
        add = a + 1
        # branch 1
        add_1 = add.relu()
        # branch 2 depends on branch 1
        add_2 = add + add_1
        # branch 3
        add_3 = add.relu()
        out = torch.stack([add_1, add_2, add_3])
        return out

    @staticmethod
    def forward12(a, b, c):
        b0 = a + 1.0
        c0 = a + 1.5
        x0 = b0.relu()
        x1 = c0.relu()
        b1 = b0 + x1
        c1 = c0 + 1.2
        # c2 depends on x0 & b0; when we merge {c0, c1, c2}, this dependency
        # must be transferred to the fusion group and reflected in the decision
        # not to fuse b0 & b1, which would otherwise form a cyclic dependency
        # in the new graph
        c2 = x0 + c0
        return b1, c2

    @staticmethod
    def forward13(a, b, c):
        a0, a1, a2, a3 = a.split(1, 0)
        b1 = a0 + b
        c1 = a1 + c
        return b1 + c1

    @staticmethod
    def forward14(a, b, c):
        a0, a1 = torch.ops.aten.std_mean(a)
        out = a0 + 1.0
        return out

    @staticmethod
    def forward15(a, b, c):
        a0 = torch.ops.aten.view(a, [2, 2])
        a1 = torch.ops.aten.permute(a0, [1, 0])
        a2 = a1 + 1.0
        a3 = torch.ops.aten.permute(a2, [1, 0])
        a4 = a3 + 1.0
        a5 = torch.ops.aten.permute(a4, [1, 0])
        return torch.ops.aten.permute(a5, [1, 0])

    @staticmethod
    def forward16(a, b, c):
        a0 = a - 1.0
        a1 = torch.ops.aten.view(a0, [2, 2])
        a2 = torch.ops.aten.permute(a1, [1, 0])
        a3 = a2 + 1.0
        a4 = torch.ops.aten.permute(a3, [1, 0])
        a5 = a4 + 1.0
        a6 = torch.ops.aten.permute(a5, [1, 0])
        a7 = torch.ops.aten.permute(a6, [1, 0])
        return a7 - 1.0

    @staticmethod
    def forward17(a, b, c, d, e, f):
        a0 = a + b
        a1 = c + d
        a2 = e + f
        return a0, a1, a2

    @staticmethod
    def forward18(a, b, c):
        a0, a1 = torch.ops.aten.var_mean(a)
        return a0

# A mock OperatorSupport class supporting operator.add, operator.getitem,
# and aten view/permute/std_mean
class MockOperatorSupport(OperatorSupport):
    def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
        return (node.op == "call_function" and
                node.target in {operator.add, operator.getitem,
                                torch.ops.aten.view,
                                torch.ops.aten.permute,
                                torch.ops.aten.std_mean})

@instantiate_parametrized_tests
class TestFXGraphPasses(JitTestCase):

    @parametrize("fn, expected_partition, bookend_non_compute_pass", [
        (TestPartitionFunctions.forward1, [["add_7", "add_6"], ["add_5", "add_4", "add_3"], ["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward2, [["add_3", "add_2"], ["add_1", "add"]], False),

        # 1: horizontal fusion with a common producer
        (TestPartitionFunctions.forward3, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward4, [["add_2", "add_1", "add"]], False),

        # 2: two-branch cases
        (TestPartitionFunctions.forward5, [["add_1", "add"]], False),
        (TestPartitionFunctions.forward6, [["add"]], False),
        (TestPartitionFunctions.forward7, [["add_3", "add_2", "add", "add_1"]], False),
        (TestPartitionFunctions.forward8, [["add_3", "add_2", "add", "add_1"]], False),

        # 3: three-branch cases
        (TestPartitionFunctions.forward9, [['add_3', 'add_2', 'add_1', 'add']], False),
        (TestPartitionFunctions.forward10, [['add_3', 'add_2', 'add', 'add_1']], False),
        (TestPartitionFunctions.forward11, [['add_1'], ['add']], False),

        # 4: not necessarily the only valid partitioning; just verifies that
        # there is no cyclic dependency after partitioning
        (TestPartitionFunctions.forward12, [["add_2", "add_3", "add_4"], ["add", "add_1"]], False),

        # 5: getitem special case
        (TestPartitionFunctions.forward13, [["add_2", "add_1", "add"]], False),
        (TestPartitionFunctions.forward14, [["add", "std_mean", "getitem", "getitem_1"]], False),

        # 6: bookend non-compute pass
        (TestPartitionFunctions.forward15, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward15, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
        (TestPartitionFunctions.forward16, [["permute_1", "add_1", "add"]], True),
        (TestPartitionFunctions.forward16, [['add_1', 'add', 'permute_1', 'view', 'permute_2', 'permute_3', 'permute']], False),
        # should yield an empty partition list, not a partition with no nodes
        (TestPartitionFunctions.forward18, [], False),
    ])
    def test_partitioner(self, fn, expected_partition, bookend_non_compute_pass):
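        """Propose partitions for `fn` with MockOperatorSupport, compare the
        proposed node sets to `expected_partition`, then fuse the partitions
        and check numerical parity with the original function."""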
        traced = symbolic_trace(fn)

        non_compute_ops = []
        if bookend_non_compute_pass:
            non_compute_ops = ["torch.ops.aten.view", "torch.ops.aten.permute"]

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True,
                                                 non_compute_ops=non_compute_ops)
        partitions = partitioner.propose_partitions()
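        # the bookend pass strips non-compute ops (here view/permute) from the
        # boundaries of each partition, since they carry no computation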
        if bookend_non_compute_pass:
            partitioner.remove_bookend_non_compute_ops(partitions)

        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c)
        result = fused_graph(a, b, c)
        torch.testing.assert_close(expected, result)

    @parametrize("fn, expected_partition", [
        (TestPartitionFunctions.forward17, [['add', 'add_1', 'add_2']]),
    ])
    def test_partitioner_independent_output(self, fn, expected_partition):
        traced = symbolic_trace(fn)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()
        partitions_name = [[node.name for node in partition.nodes] for partition in partitions]
        assert len(partitions_name) == len(expected_partition)
        for i in range(len(partitions_name)):
            assert set(partitions_name[i]) == set(expected_partition[i])

        fused_graph = partitioner.fuse_partitions(partitions)

        a, b, c, d, e, f = torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4), torch.rand(4)

        expected = fn(a, b, c, d, e, f)
        result = fused_graph(a, b, c, d, e, f)
        torch.testing.assert_close(expected, result)

    @parametrize("partition", [
        [['add', 'add_1'], ['add_5', 'add_6']],
        [['add', 'add_1', 'add_2']],  # vertical fusion
        [['add_2', 'add_3']],         # horizontal fusion
        [['add_3', 'add_4']],
        [['add_6', 'add_5']],     # arbitrary node order
        [['add_4', 'add_1', 'add_3', 'add_2']],           # arbitrary node order
        [['add_5', 'add_6'], ['add_1', 'add_2', 'add_3', 'add_4']],  # arbitrary partition order
        [['add_5', 'linear2']],   # includes call_function + call_module nodes
        [['add_6', 'relu']],   # includes call_function + call_method nodes
        [['param', 'add_2']],   # includes get_attr + call_function nodes
        [['param', 'add_1', 'linear']],   # includes get_attr + call_function + call_module nodes
        [["add", "linear", "add_1", "param", "add_2", "add_3", "add_4", "linear2", "add_5", "add_6", "relu"]],  # full graph
    ])
    def test_fuser_util(self, partition):
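        """Fuse the given groups of nodes in TestModule's traced graph and
        verify the fused graph matches eager execution numerically."""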
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name : node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append([nodes_by_name[name] for name in node_names])

        fused_graph = fuse_by_partitions(gm, partitions)

        a, b, c = torch.rand(4), torch.rand(4), torch.rand(4)

        expected = m(a, b, c)
        result = fused_graph(a, b, c)

        torch.testing.assert_close(expected, result)

    @parametrize("partition", [
        [['add', 'add_1'], ['add_1', 'add_5', 'add_6']],  # add_1 exists in multiple partitions
        [['add', 'add_1', 'add_3']],    # invalid partition: circular dependency
        [['add_4', 'add_5']],    # invalid partition: circular dependency
        [['relu', 'add_5']],    # invalid partition: circular dependency
    ])
    def test_fuser_util_xfail(self, partition):
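        """Invalid partitions (a node shared across partitions, or groupings
        that would create a cyclic dependency) must raise."""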
        m = TestModule()
        gm = symbolic_trace(m)

        nodes_by_name = {node.name : node for node in gm.graph.nodes}

        partitions = []
        for node_names in partition:
            partitions.append([nodes_by_name[name] for name in node_names])

        with self.assertRaises(Exception):
            fuse_by_partitions(gm, partitions)

    def test_fuser_pass_deep_model(self):
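        """propose_partitions must handle a graph deeper than Python's
        recursion limit, i.e. it cannot use per-node recursion."""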
        m = TestDeepModule()
        traced = symbolic_trace(m)

        supported_ops = MockOperatorSupport()
        partitioner = CapabilityBasedPartitioner(traced,
                                                 supported_ops,
                                                 allows_single_node_partition=True)
        partitions = partitioner.propose_partitions()

@dataclass
class TestCase:
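    # One SubgraphMatcher configuration and its expected result:
    #   match_output: treat the pattern's output node as part of the pattern
    #   match_placeholder: treat the pattern's placeholders as part of the pattern
    #   num_matches: expected number of matches returned by match()
    #   remove_overlapping_matches: deduplicate matches that share nodes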
    match_output: bool
    match_placeholder: bool
    num_matches: int
    remove_overlapping_matches: bool = True

class SingleNodePattern:
    @staticmethod
    def forward(x):
        val = torch.neg(x)
        return torch.add(val, val)

    @staticmethod
    def pattern(a):
        return torch.neg(a)

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class SimplePattern:
    @staticmethod
    def forward(x, w1, w2):
        m1 = torch.cat([w1, w2]).sum()
        m2 = torch.cat([w2, w1]).sum()
        m3 = torch.cat([m1, m2]).sum()
        return x + torch.max(m1) + torch.max(m2) + m3

    @staticmethod
    def pattern(a, b):
        return torch.cat([a, b]).sum()

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 3),
        TestCase(True, False, 0),
        TestCase(False, True, 2),
        TestCase(True, True, 0)
    ]

class SimpleFullGraphMatching:
    @staticmethod
    def forward(x):
        a = torch.neg(x)
        return torch.add(a, a)

    @staticmethod
    def pattern(x):
        a = torch.neg(x)
        return torch.add(a, a)

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 1),
        TestCase(True, True, 1)
    ]

class DiamondShapePatternTestCase:
    @staticmethod
    def forward(x):
        a = torch.neg(x)

        a = a.relu()
        left = a.sigmoid()
        right = a.relu()
        out = left + right

        return out

    @staticmethod
    def pattern(a):
        a = a.relu()
        left = a.sigmoid()
        right = a.relu()
        out = left + right
        return out

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class NonFullyContainedMatches:
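    # A candidate match is discarded when an intermediate node of the matched
    # subgraph is also used outside the match ("leaking"); uses of the match's
    # outputs or placeholders do not disqualify it.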
    @staticmethod
    def forward(x, w1, w2, b1, b2):
        # fully contained matched subgraph
        m1 = torch.cat([w1, w2])
        m2 = torch.cat([x, b2])
        t0 = torch.addmm(b1, m1, m2.t())
        t0_sum = torch.sum(t0)   # t0 is the match's output, so this use does not leak

        # leaking matched subgraph: intermediate node m3 is used outside the match
        m3 = torch.cat([w1, w2])
        m4 = torch.cat([x, b2])
        t1 = torch.addmm(b1, m3, m4.t())
        m3_sum = torch.sum(m3)

        return t0_sum, m3_sum

    @staticmethod
    def pattern(x, w1, w2, b1, b2):
        m1 = torch.cat([w1, w2])
        m2 = torch.cat([x, b2])
        return torch.addmm(b1, m1, m2.t())

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),     # a leaked use of a placeholder does not count as leaking
    ]

class ChainRepeatedPattern:
    @staticmethod
    def forward(x):
        x = torch.sigmoid(x)
        x = torch.sigmoid(x)
        x = torch.sigmoid(x)
        return torch.sigmoid(x)

    @staticmethod
    def pattern(x):
        return torch.sigmoid(torch.sigmoid(x))

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 3, remove_overlapping_matches=False),
        TestCase(False, False, 2, remove_overlapping_matches=True),
        TestCase(True, False, 1),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class QuantizationModel:
    @staticmethod
    def forward(x):
        x += 3
        x = x.dequantize()
        x = torch.sigmoid(x)
        x = x.to(torch.float16)
        return x

    @staticmethod
    def pattern(x):
        x = x.dequantize()
        x = torch.sigmoid(x)
        x = x.to(torch.float16)
        return x

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultipleOutputsWithDependency:
    @staticmethod
    def forward(x):
        y = x.relu()
        z = y.sigmoid()
        return z, y

    @staticmethod
    def pattern(a):
        b = a.relu()
        c = b.sigmoid()
        return b, c     # outputs have a data dependency (c is computed from b)

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 1),
        TestCase(True, True, 0)
    ]

class MultipleOutputsWithoutDependency:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        y = x.sigmoid()

        out = y.sigmoid() + z.sum()
        return out

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultipleOutputsMultipleOverlappingMatches:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        z1 = x.sum()
        y = x.sigmoid()
        y1 = x.sigmoid()

        return z + z1 + y + y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return a, b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 4, remove_overlapping_matches=False),
        TestCase(False, False, 1, remove_overlapping_matches=True),
    ]

class MultipleOutputsMultipleNonOverlappingMatches:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        z = x.sum()
        y = x.sigmoid()

        x = x.relu()
        z1 = x.sum()
        y1 = x.sigmoid()

        return z + z1 + y + y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        c = a.sum()
        return b, c

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]

class MultipleOutputsIdenticalAnchor:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        x = x.relu()
        y = x.sigmoid()
        y1 = x.sigmoid()

        return y, y1

    @staticmethod
    def pattern(a):
        a = a.relu()
        b = a.sigmoid()
        b1 = a.sigmoid()
        return b, b1

    test_cases = [
        # match_output, match_placeholder, num_matches
        # TestCase(False, False, 2),  # FIXME: currently still matches 2; should be fixed to match 1
        TestCase(True, False, 1),
        TestCase(False, True, 0),
    ]

class MultipleOutputsHorizontalPattern:
    @staticmethod
    def forward(x):
        x = x + 1

        # target subgraph to match
        y1 = x.relu()
        y2 = x.sigmoid()

        return y1, y2

    @staticmethod
    def pattern(a):
        b1 = a.relu()
        b2 = a.sigmoid()

        return b1, b2

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
        TestCase(True, False, 1),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

class MultiOutputWithInvalidMatches:
    @staticmethod
    def forward(x):
        res0 = torch.nn.functional.linear(x, torch.rand(3, 3))
        res1 = torch.sigmoid(res0)
        res2 = res0 * res1
        res3 = torch.sum(res2, dim=1)
        return res3

    @staticmethod
    def pattern(a, b, c):
        lin_res = torch.nn.functional.linear(a, b)
        mul_res = lin_res * c
        return lin_res, mul_res

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 0),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
    ]

class QuantizationFp8Pattern:
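    # setup() registers a custom "fp8_quantization" op library so the pattern
    # can reference torch.ops.fp8_quantization.* while tracing; tearDown()
    # releases it after the test.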
    @classmethod
    def setup(cls):
        cls.quantization = torch.library.Library("fp8_quantization", "DEF")  # noqa: TOR901
        cls.quantization.define("quantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")
        cls.quantization.define("dequantize_per_tensor_affine_fp8(Tensor self, int dtype, float scale) -> Tensor")

    @classmethod
    def tearDown(cls):
        del cls.quantization

    @staticmethod
    def forward(self, arg0_1, arg1_1):
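        # `forward` is a @staticmethod traced as a free function, so `self` is
        # an ordinary placeholder and `self._scale_0` traces to a getattr node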
        qt = torch.ops.fp8_quantization
        _scale_0 = self._scale_0
        quantize_per_tensor_affine_fp8 = qt.quantize_per_tensor_affine_fp8(arg0_1, 0, _scale_0)
        dequantize_per_tensor_affine_fp8 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8, 0, _scale_0)
        _scale_1 = self._scale_0
        quantize_per_tensor_affine_fp8_1 = qt.quantize_per_tensor_affine_fp8(arg1_1, 0, _scale_1)
        dequantize_per_tensor_affine_fp8_1 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_1, 0, _scale_1)
        add = torch.ops.aten.add.Tensor(dequantize_per_tensor_affine_fp8, dequantize_per_tensor_affine_fp8_1)
        _scale_2 = self._scale_0
        quantize_per_tensor_affine_fp8_2 = qt.quantize_per_tensor_affine_fp8(add, 0, _scale_2)
        dequantize_per_tensor_affine_fp8_2 = qt.dequantize_per_tensor_affine_fp8(quantize_per_tensor_affine_fp8_2, 0, _scale_2)
        return dequantize_per_tensor_affine_fp8_2

    @staticmethod
    def pattern(a, a_dtype, a_scale, b, b_dtype, b_scale, out_scale):
        qt = torch.ops.fp8_quantization
        a = qt.dequantize_per_tensor_affine_fp8(a, a_dtype, a_scale)
        b = qt.dequantize_per_tensor_affine_fp8(b, b_dtype, b_scale)
        output = torch.ops.aten.add.Tensor(a, b)

        output = qt.quantize_per_tensor_affine_fp8(output, a_dtype, out_scale)
        return output

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 1),
    ]

class NoAnchorFound:
    # This test case covers patterns with no matching anchor in the target graph.
    # The `anchor` is the starting point of pattern matching; it is usually a
    # boundary returning node.
    @staticmethod
    def forward(x):
        x = x + 1
        return x

    @staticmethod
    def pattern(a):
        b1 = a.relu()
        return b1

    test_cases = [
        # match_output, match_placeholder, num_matches
        TestCase(False, False, 0),
        TestCase(True, False, 0),
        TestCase(False, True, 0),
        TestCase(True, True, 0)
    ]

@instantiate_parametrized_tests
class TestFXMatcherUtils(JitTestCase):

    @parametrize("test_model", [
        SingleNodePattern,
        SimplePattern,
        SimpleFullGraphMatching,
        DiamondShapePatternTestCase,
        NonFullyContainedMatches,
        ChainRepeatedPattern,
        QuantizationModel,
        MultipleOutputsWithDependency,
        MultipleOutputsWithoutDependency,
        MultipleOutputsMultipleOverlappingMatches,
        MultipleOutputsMultipleNonOverlappingMatches,
        MultipleOutputsIdenticalAnchor,
        MultipleOutputsHorizontalPattern,
        MultiOutputWithInvalidMatches,
        QuantizationFp8Pattern,
        NoAnchorFound,
    ])
    def test_subgraph_matcher(self, test_model):
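        """Trace each model and its pattern, run SubgraphMatcher under every
        TestCase configuration, and check both the match count and that every
        non-skipped pattern node appears in each match's nodes_map."""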

        setup = getattr(test_model, "setup", None)
        if callable(setup):
            setup()

        traced = symbolic_trace(test_model.forward)
        pattern_traced = symbolic_trace(test_model.pattern)

        for test_case in test_model.test_cases:

            matcher = SubgraphMatcher(pattern_traced.graph,
                                      match_output=test_case.match_output,
                                      match_placeholder=test_case.match_placeholder,
                                      remove_overlapping_matches=test_case.remove_overlapping_matches)
            matches = matcher.match(traced.graph)

            assert len(matches) == test_case.num_matches

            for match in matches:
                for node in pattern_traced.graph.nodes:
                    if not test_case.match_placeholder and node.op == "placeholder":
                        continue
                    if not test_case.match_output and node.op == "output":
                        continue
                    assert node in match.nodes_map

        tearDown = getattr(test_model, "tearDown", None)
        if callable(tearDown):
            tearDown()


if __name__ == "__main__":
    run_tests()
