# Owner(s): ["module: fx"]

import functools
import math
import numbers
import operator
import pickle
import sys
import sympy
import tempfile
import unittest
from types import BuiltinFunctionType
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union

import torch
import torch.fx.experimental.meta_tracer
import torch.fx.experimental.optimization as optimization
from torch.fx._symbolic_trace import symbolic_trace
from torch.fx.experimental import merge_matmul
from torch.fx.experimental.accelerator_partitioner import Partitioner
from torch.fx.experimental.normalize import NormalizeArgs, NormalizeOperators
from torch.fx.experimental.partitioner_utils import (
    Device,
    get_latency_of_partitioned_graph,
    get_partition_to_latency_mapping,
    NodeLatency,
    PartitionerConfig,
    PartitionMode,
)
from torch.fx.experimental.rewriter import RewritingTracer
from torch.fx.experimental.schema_type_annotation import AnnotateTypesWithSchema
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx.operator_schemas import (
    _torchscript_type_to_python_type,
    create_type_hint,
    normalize_function,
    normalize_module,
    type_matches,
)
from torch.fx.passes import graph_manipulation
from torch.fx.passes.param_fetch import lift_lowering_attrs_to_nodes
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.passes.split_module import split_module
from torch.fx.passes.annotate_getitem_nodes import annotate_getitem_nodes
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_nn import module_tests, new_module_tests
from torch.testing._internal.common_utils import TEST_Z3, run_tests, TestCase
from torch.testing._internal.jit_utils import JitTestCase
import torch.utils._pytree as pytree

try:
    import torchvision.models
    from torchvision.models import resnet18

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
skipIfNoMkldnn = unittest.skipIf(
    not (torch.backends.mkldnn.enabled and torch.backends.mkldnn.is_available()),
    "no MKLDNN",
)


def symbolic_trace_with_rewrite(root: Union[torch.nn.Module, Callable]) -> GraphModule:
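    """Trace ``root`` with RewritingTracer, which rewrites Python ``assert``
    statements into ``torch._assert`` calls so they survive symbolic tracing.
    """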
    return GraphModule(
        root if isinstance(root, torch.nn.Module) else torch.nn.Module(),
        RewritingTracer().trace(root),
    )


class TestFXExperimental(JitTestCase):
    def test_find_single_partition(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(1)
        b = torch.rand(1)
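        # get_size_of_all_nodes propagates the sample inputs through the graph
        # and records size metadata (node.size_bytes) for the partitioner.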
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
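        # Each Device is (name, available_mem_bytes, logical_id).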
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 150, 1),
            Device("dev_2", 125, 2),
        ]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
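        # The lone partition is expected on dev_1, the device with the most
        # available memory.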
        assert dag.nodes[0].logical_device_ids == [1]

    def test_lack_of_devices(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 4, 0), Device("dev_1", 4, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        with self.assertRaises(RuntimeError):
            partitioner.partition_graph(traced, m, partitioner_config)

    def test_large_node_error(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                linear = self.linear(a)
                add = linear + a
                return add

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 40, 0),
            Device("dev_1", 40, 0),
            Device("dev_2", 40, 0),
            Device("dev_3", 40, 0),
            Device("dev_4", 40, 0),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        with self.assertRaises(RuntimeError):
            partitioner.partition_graph(traced, m, partitioner_config)

    def test_partition_node_manipulation(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                add_1 = a + b
                add_2 = add_1 + torch.rand(4)
                add_3 = add_2 + torch.rand(4)
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a, b = torch.rand(4), torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [Device("dev_0", 1000, 0)]
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        partition = partitioner.partitions[0]
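        # Every tensor here is 4 fp32 elements = 16 bytes; the partition holds
        # a, b, two rand constants, and three add outputs: 7 * 16 = 112 bytes.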
        assert partition.used_mem_bytes == 112
        # Select the add_2 node to remove
        selected_node = None
        for node in partition.nodes:
            if node.name == "add_2":
                selected_node = node
        assert selected_node is not None, "add_2 node not found in partition"
        partition.remove_node(selected_node)
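        # Removing add_2 frees its 16-byte output and, in the partitioner's
        # accounting, its now-unused 16-byte rand input: 112 - 32 = 80.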
        assert partition.used_mem_bytes == 80

    def test_size_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)
                self.c = torch.rand(4)

            def forward(self, a, b):
                add_1 = a + b
                linear = self.linear(add_1)
                add_2 = linear + self.c
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        b = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b])
        partitioner = Partitioner()
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b), module_with_submodules(a, b))
        for i, node in enumerate(dag.nodes):
            assert node.logical_device_ids == [i]

    def test_partition_device_mapping(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                b = torch.rand(4)
                add_1 = a + b
                linear_1 = self.linear(add_1)
                add_2 = torch.rand(4) + a
                add_3 = add_2 + linear_1
                return add_3

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        partitioner = Partitioner()
        devices = [Device("dev_0", 120, 0), Device("dev_1", 160, 1)]
        partitioner_config = PartitionerConfig(devices, PartitionMode.size_based)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        for i, node in enumerate(dag.nodes):
            if i == 1:
                assert node.logical_device_ids == [1]
            else:
                assert node.logical_device_ids == [0]

    def test_sparse_nn_partition(self):
        class MyRecommendationModule(torch.nn.Module):
            def create_mlp(self, num_of_layers: int, input_size: int, output_size: int):
                layers = torch.nn.ModuleList()
                for _ in range(num_of_layers):
                    ll = torch.nn.Linear(input_size, output_size)
                    layers.append(ll)
                    layers.append(torch.nn.ReLU())
                return layers

            def __init__(self) -> None:
                super().__init__()
                layers = self.create_mlp(4, 4, 4)
                self.bottom_layers = torch.nn.Sequential(*layers)
                layers = self.create_mlp(3, 24, 24)
                self.top_layers = torch.nn.Sequential(*layers)
                self.embedding_layers = torch.nn.ModuleList()
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)
                for i in range(3):
                    el = torch.nn.EmbeddingBag(1000000, 4, mode="sum", sparse=True)
                    self.embedding_layers.append(el)
                el = torch.nn.EmbeddingBag(500000, 4, mode="sum", sparse=True)
                self.embedding_layers.append(el)

            def forward(self, a, b, offset):
                x = self.bottom_layers(a)
                y = []
                c = []
                for i in range(len(self.embedding_layers)):
                    temp = torch.randint(10, (8,))
                    c.append(temp + b)
                for i in range(len(self.embedding_layers)):
                    if i % 2 == 0:
                        y.append(self.embedding_layers[i](c[i], offset))
                    else:
                        y.append(
                            self.embedding_layers[i](torch.randint(10, (8,)), offset)
                        )
                z = torch.cat([x] + y, dim=1)
                p = self.top_layers(z)
                return p

        m = MyRecommendationModule()
        a = torch.rand(2, 4)
        b = torch.randint(10, (8,))
        offset = torch.randint(1, (2,))
        traced = symbolic_trace(m)
        graph_manipulation.get_size_of_all_nodes(traced, [a, b, offset])
        devices = [
            Device("dev_0", 33000000, 0),
            Device("dev_1", 33000000, 1),
            Device("dev_2", 33000000, 2),
        ]
        partitioner_config = PartitionerConfig(devices, PartitionMode.sparse_nn)
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a, b, offset), module_with_submodules(a, b, offset))
        assert len(module_with_submodules.graph.nodes) == 24

    def test_partition_latency(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        def get_node_to_latency_mapping(fx_module: GraphModule):
            """Given a fx module, generate node latency for each node
            based on the size of each node
            """
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 2.0 * node.size_bytes.total_size
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        devices = [Device("dev_0", 200, 0), Device("dev_1", 200, 1)]
        partitioner = Partitioner()
        partitioner_config = PartitionerConfig(devices)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        for p in partition_to_latency_mapping:
            if p.partition_id == 0:
                assert partition_to_latency_mapping[p] == (128.0, 80.0, 160.0)
            else:
                assert partition_to_latency_mapping[p] == (16.0, 32.0, 32.0)
        transfer_rate_bytes_per_sec = 2
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
        )
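        # Expected critical path: partition 0's overall latency (160.0) plus a
        # 32-byte transfer at 2 bytes/sec (16.0) plus partition 1's overall
        # latency (32.0) = 208.0.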
        assert critical_path_latency_sec == 208.0

    def test_cost_aware_partition(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + torch.rand(4)
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        def get_node_to_latency_mapping(fx_module: GraphModule):
            node_to_latency_mapping: Dict[Node, NodeLatency] = {}
            for node in fx_module.graph.nodes:
                if node.op not in {"output", "placeholder", "get_attr"}:
                    if node.size_bytes.total_size == node.size_bytes.output_size:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, 1
                        )
                    else:
                        node_to_latency_mapping[node] = NodeLatency(
                            node.size_bytes.total_size, node.size_bytes.output_size
                        )
            return node_to_latency_mapping

        m = MyModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 125, 0),
            Device("dev_1", 125, 1),
            Device("dev_2", 125, 2),
            Device("dev_3", 125, 3),
        ]
        node_to_latency_mapping = get_node_to_latency_mapping(traced)
        partitioner_config = PartitionerConfig(
            devices,
            mode=PartitionMode.cost_aware,
            transfer_rate_bytes_per_sec=2,
            node_to_latency_mapping=node_to_latency_mapping,
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(traced(a), module_with_submodules(a))
        partitions = partitioner.partitions
        partition_to_latency_mapping = get_partition_to_latency_mapping(
            partitions, node_to_latency_mapping
        )
        critical_path_latency_sec = get_latency_of_partitioned_graph(
            partitions,
            partition_to_latency_mapping,
            partitioner_config.transfer_rate_bytes_per_sec,
        )
        assert critical_path_latency_sec == 160.0

    def test_aot_based_partition(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.rand(4)
                self.c = torch.rand(4)

            def forward(self, a):
                add_1 = a + self.b
                add_2 = self.c + add_1
                return add_2

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
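        # aot_based mode consumes precomputed mappings: give every non-trivial
        # node its own partition, and map every partition to logical device 0.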
        node_to_partition_id = {}
        partition_to_logical_devices = {}
        count = 0
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        for node in traced.graph.nodes:
            if node.op not in {"placeholder", "get_attr", "output"}:
                node_to_partition_id[node] = count
                partition_to_logical_devices[count] = [0]
                count += 1
        devices = [Device("dev_0", 200, 0)]
        partitioner_config = PartitionerConfig(
            devices=devices,
            mode=PartitionMode.aot_based,
            node_to_partition_mapping=node_to_partition_id,
            partition_to_logical_device_mapping=partition_to_logical_devices,
        )
        partitioner = Partitioner()
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        dag = ret.dag
        self.assertEqual(module_with_submodules(a), traced(a))
        for node in dag.nodes:
            assert node.size_bytes == 48
            assert node.logical_device_ids == [0]

    def test_replace_target_nodes_with(self):
        class TestModule(torch.nn.Module):
            def forward(self, a, b):
                return a + b

        m = TestModule()
        traced = symbolic_trace(m)
        input1 = torch.randn(1)
        input2 = torch.randn(1)
        assert (input1 + input2) == traced(input1, input2)
        graph_manipulation.replace_target_nodes_with(
            fx_module=traced,
            old_op="call_function",
            old_target=operator.add,
            new_op="call_function",
            new_target=operator.mul,
        )
        assert (input1 * input2) == traced(input1, input2)

    def test_saturate_host(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a):
                add_1 = a + torch.rand(4)
                add_2 = add_1 + torch.rand(4)
                linear_1 = self.linear(add_1)
                add_3 = add_2 + linear_1
                add_4 = add_2 + add_3
                return add_4

        m = TestModule()
        traced = symbolic_trace(m)
        a = torch.rand(4)
        graph_manipulation.get_size_of_all_nodes(traced, [a])
        devices = [
            Device("dev_0", 200, 0),
            Device("dev_1", 200, 1),
            Device("dev_2", 100, 2),
            Device("dev_3", 100, 3),
            Device("dev_4", 200, 4),
            Device("dev_5", 100, 5),
        ]
        partitioner = Partitioner()
        # Without host saturation, the model will be split into two partitions.
        # dev_0 holds partition 0 of 192 bytes and dev_1 holds partition 1 of 48 bytes.
        partitioner_config = PartitionerConfig(devices, saturate_host=True)
        ret = partitioner.partition_graph(traced, m, partitioner_config)
        module_with_submodules = ret.module_with_submodules
        self.assertEqual(traced(a), module_with_submodules(a))

        partitions = partitioner.partitions
        self.assertEqual(len(partitions), 2)
        # With host saturation, partition 0 is replicated to dev_4, and
        # partition 1 is replicated to dev_2.
        self.assertEqual(partitions[0].logical_device_ids, [0, 4])
        self.assertEqual(partitions[1].logical_device_ids, [1, 2])

    @skipIfNoTorchVision
    def test_conv_bn_fusion(self):
        rn18 = resnet18().eval()
        traced = symbolic_trace(rn18)
        fused = optimization.fuse(traced)

        self.assertTrue(
            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )

        N, C, H, W = 20, 3, 224, 224
        inp = torch.randn(N, C, H, W)

        self.assertEqual(fused(inp), rn18(inp))

    def test_conv_bn_fusion_not_running_state(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(32, 64, 3, stride=2)
                self.bn = torch.nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = M().eval()

        traced = symbolic_trace(model)
        fused = optimization.fuse(traced)
        inp = torch.randn([1, 32, 50, 50])

        # The BatchNorm should not be folded into the Conv because
        # track_running_stats=False.
        self.assertTrue(
            any(isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )
        self.assertEqual(fused(inp), model(inp))

    def test_conv_bn_fusion_mixed_dtype(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False, dtype=torch.bfloat16)
                self.bn = torch.nn.BatchNorm2d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)

            def forward(self, x):
                x = self.conv(x)
                x = self.bn(x)
                return x

        model = M().eval()

        traced = symbolic_trace(model)
        fused = optimization.fuse(traced)
        inp = torch.randn(1, 3, 64, 64, dtype=torch.bfloat16)

        self.assertTrue(
            all(not isinstance(m, torch.nn.BatchNorm2d) for m in fused.modules())
        )
        self.assertEqual(fused(inp), model(inp))

    def test_call_to_assert_no_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_meta_tracer(self):
        class MetaTracerTestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.emb = torch.nn.Embedding(num_embeddings=42, embedding_dim=16)
                self.layernorm = torch.nn.LayerNorm(16)

            def forward(self, x):
                emb = self.emb(x)
                emb = emb + torch.arange(emb.shape[-1], dtype=torch.float, device=emb.device)
                lol = self.layernorm(emb)
                return torch.relu(lol) if lol.shape[0] < 30 else torch.sigmoid(lol)

        mttm = MetaTracerTestModule()
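        # Trace once per batch size: the shape-dependent branch in forward
        # (lol.shape[0] < 30) specializes each traced graph to its input.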
        for BS in [15, 35]:
            x = torch.zeros(BS, dtype=torch.long).random_(42)
            meta_args = {'x' : x.to(device='meta')}
            gm = torch.fx.experimental.meta_tracer.symbolic_trace(mttm, meta_args=meta_args)
            torch.testing.assert_close(gm(x), mttm(x))

            # Test serialization/deserialization
            with tempfile.TemporaryDirectory() as tmp_dir:
                with open(f'{tmp_dir}/meta_module.pkl', 'wb') as f:
                    pickle.dump(gm, f)

                with open(f'{tmp_dir}/meta_module.pkl', 'rb') as f:
                    loaded = pickle.load(f)

                torch.testing.assert_close(loaded(x), mttm(x))


    def test_call_to_assert_with_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, "test message"
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, "test message"):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_empty_msg(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                assert a == b, ""
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, ""):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_call_to_assert_with_multiline_message(self):
        class M(torch.nn.Module):
            def forward(self, a, b):
                error_msg = """
An error message with
terrible spacing
                """
                assert a == b, error_msg
                return a + b

        m = M()
        traced = symbolic_trace_with_rewrite(m)

        # Make sure the graph is well-formed
        traced.graph.lint()

        # Check the IR to make sure there's a call_function node with target == torch._assert
        self.assertTrue(
            any(
                node.op == "call_function" and node.target == torch._assert
                for node in traced.graph.nodes
            )
        )

        # Ensure that the assert throws when it's supposed to and doesn't throw when it's not supposed to
        error_msg = """
An error message with
terrible spacing
    """
        traced(3, 3)
        with self.assertRaisesRegex(AssertionError, error_msg):
            traced(3, 5)

        # Confirm that the output is correct
        self.assertEqual(traced(3, 3), m(3, 3))

    def test_subgraph_creation(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 4))
                self.linear = torch.nn.Linear(4, 5)

            def forward(self, x, y):
                z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                w = self.linear(y).clamp(min=0.0, max=1.0)
                return z + w

        # symbolically trace model
        my_module = MyModule()
        my_module_traced = symbolic_trace(my_module)

        # random mod partitioning
        partition_counter = 0
        NPARTITIONS = 3

        # Add some random meta info to make sure it is kept around.
        for node in my_module_traced.graph.nodes:
            if node.op != "output":
                node.meta["test_meta_info"] = True

        def mod_partition(node: Node):
            nonlocal partition_counter
            partition = partition_counter % NPARTITIONS
            partition_counter = (partition_counter + 1) % NPARTITIONS
            return partition

        # split the module into a module with submodules
        module_with_submodules = split_module(
            my_module_traced, my_module, mod_partition
        )

        # Check that test_meta_info was still on all nodes.
        submodules = dict(module_with_submodules.named_modules())
        for node in module_with_submodules.graph.nodes:
            if node.op == "call_module":
                submod = submodules[node.target]
                self.assertTrue(isinstance(submod, torch.fx.GraphModule))
                for submod_node in submod.graph.nodes:
                    if submod_node.op != "output":
                        self.assertTrue(submod_node.meta.get("test_meta_info"))

        x = torch.rand(3, 4)
        y = torch.rand(3, 4)

        orig_out = my_module_traced(x, y)
        submodules_out = module_with_submodules(x, y)

        self.assertEqual(orig_out, submodules_out)

    def test_split_module_dead_code(self):
        class ModWithDeadCode(torch.nn.Module):
            def forward(self, x):
                output = x * 2  # we want this
                dead_line = x + 2  # this is dead
                return output

        mod = ModWithDeadCode()
        traced = torch.fx.symbolic_trace(mod)

        # split into before (0), target (1), and after (2)
        saw_mul = False

        def split_callback(n):
            nonlocal saw_mul
            if n.target == operator.mul:
                saw_mul = True
                return 1

            return 2 if saw_mul else 0

        split = split_module(traced, mod, split_callback)

        x = torch.randn((5,))
        torch.testing.assert_close(split(x), traced(x))


    def test_split_module_kwargs_expansion(self):
        class ModuleWithKwargsExpansion(torch.nn.Module):
            def forward(self, x, **kwargs):
                return x + kwargs['foo']

        mod = ModuleWithKwargsExpansion()
        traced = torch.fx.symbolic_trace(mod)

        seen_getitem = False
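        # Keep the kwargs getitem in partition 0 and every later node in
        # partition 1, so the split boundary crosses the **kwargs expansion.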

        def split_callback(n):
            nonlocal seen_getitem
            split_idx = int(seen_getitem)
            if n.target == operator.getitem:
                seen_getitem = True
            return split_idx

        split = split_module(traced, mod, split_callback)

        x = torch.randn(5, 3)
        foo = torch.randn(5, 3)
        torch.testing.assert_close(split(x, foo=foo), traced(x, foo=foo))

    @skipIfNoTorchVision
    def test_subgraph_trivial_resnet(self):
        # Smoke test that trivially splitting ResNet into one partition works.
        # There was previously an issue that caused submodule names to be aliased.
        m = resnet18()
        traced = symbolic_trace(m)
        a = torch.rand(64, 3, 7, 7)
        module_with_submodules = split_module(traced, m, lambda node: 0)
        module_with_submodules(a)

    def test_split_module_default_arg(self):
        class ModelToTrace(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = torch.nn.Linear(512, 512)

            def forward(self, x, targets=None):
                x = self.lin(x)

                if targets is not None:
                    x = x + targets

                return x

        mtt = ModelToTrace()
        traced = torch.fx.symbolic_trace(mtt, concrete_args={'targets': None})
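        # Tracing with concrete_args={'targets': None} burns in the
        # targets-is-None branch, so the traced module takes only x.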

        split = split_module(traced, mtt, lambda node: 0)

        x = torch.randn(50, 512)
        torch.testing.assert_close(split(x), traced(x))

    def test_normalize_binary_operators(self):
        ops_to_test = {
            torch.add,
            torch.mul,
            torch.sub,
            torch.div,
            torch.floor_divide,
            torch.remainder,
            torch.eq,
            torch.ne,
            torch.lt,
            torch.le,
            torch.gt,
            torch.ge,
        }
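        # NormalizeOperators should rewrite each of these calls to an
        # equivalent torch.Tensor method overload, so none of the original
        # function targets remain in the normalized graph.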

        # Test Tensor/Tensor callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x, y):
                    return op(x, y)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x, y = torch.randn(3, 4), torch.randn(3, 4)
            torch.testing.assert_close(traced(x, y), normalized(x, y))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )

        # Test Tensor/scalar callsite
        for op in ops_to_test:

            class WrapperMod(torch.nn.Module):
                def forward(self, x):
                    return op(x, 42)

            traced = symbolic_trace(WrapperMod())
            normalized = NormalizeOperators(traced).transform()
            x = torch.randn(3, 4)
            torch.testing.assert_close(traced(x), normalized(x))
            self.assertFalse(
                any(n.target in ops_to_test for n in normalized.graph.nodes)
            )

    @skipIfNoTorchVision
    def test_normalize_args(self):
        m = resnet18()

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

        traced = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        input = torch.randn(5, 3, 224, 224)
        ref_outs = traced(input)

        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()

        modules = dict(traced.named_modules())

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target != operator.add:
                self.assertEqual(len(node.args), 0)
            elif node.op == "call_module":
                submod_class = modules[node.target].__class__
                nn_class = getattr(torch.nn, submod_class.__name__)
                if submod_class == nn_class:
                    self.assertEqual(len(node.args), 0)
        traced(input)
        self.assertEqual(traced(input), ref_outs)

    def test_normalize_modules_exhaustive(self):
        """
        Exhaustively test `Node.normalized_arguments` on all standard
        torch.nn Module classes
        """
        for test_params in module_tests + new_module_tests:
            if "constructor" not in test_params:
                constructor = getattr(torch.nn, test_params["module_name"])
            else:
                constructor = test_params["constructor"]

            if "constructor_args" not in test_params:
                args = ()
            else:
                args = test_params["constructor_args"]

            mod = constructor(*args)
            # Skip modules that are not standard `torch.nn`
            # instances, including functionals. (functionals
            # are tested in test_normalize_args)
            if mod.__class__.__name__ not in dir(torch.nn):
                continue

            if "input_fn" not in test_params:
                inputs = torch.randn(test_params["input_size"])
            else:
                inputs = test_params["input_fn"]()

            if not isinstance(inputs, (tuple, list)):
                inputs = (inputs,)

            params = ", ".join(f"v{i}" for i in range(len(inputs)))

            # Generate a class to wrap this standard `nn.Module` instance
            test_classname = f"Test{mod.__class__.__name__}"
            test_mod_code = f"""
class {test_classname}(torch.nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, {params}):
        return self.mod({params})
            """

            gbls = {"torch": torch}
            exec(test_mod_code, gbls)

            test_instance = gbls[test_classname](mod)
            traced = symbolic_trace(test_instance)

            # Use `Node.normalized_arguments` to get a new set of arguments
            # to feed to the Module. Then, rewrite the node to only take
            # in those arguments as kwargs
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        normalized_args = node.normalized_arguments(traced)
                        normalized_args2 = normalize_module(
                            traced, node.target, node.args, node.kwargs
                        )
                        assert normalized_args == normalized_args2
                        assert normalized_args
                        node.args = normalized_args.args
                        node.kwargs = normalized_args.kwargs

            traced.recompile()

            # These Modules have an RNG in their forward, so comparing
            # outputs is not a valid correctness check. Skip that check
            # for these.
            stochastic_modules = {"FractionalMaxPool2d", "FractionalMaxPool3d", "RReLU"}

            if mod.__class__.__name__ not in stochastic_modules:
                self.assertEqual(traced(*inputs), mod(*inputs))

            traced = NormalizeArgs(symbolic_trace(test_instance)).transform()
            modules = dict(traced.named_modules())
            for node in traced.graph.nodes:
                if node.op == "call_module":
                    submod_class = modules[node.target].__class__
                    nn_class = getattr(torch.nn, submod_class.__name__)
                    if submod_class == nn_class:
                        self.assertEqual(len(node.args), 0)

    def test_normalize_args_preserve_meta(self):
        class MyModule(torch.nn.Module):
            def forward(self, a):
                return torch.add(a, 3)

        m = MyModule()
        traced = symbolic_trace(m)

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                node.meta["my_key"] = 7
                break
        else:
            self.fail("Didn't find call_function torch.add")

        input = torch.randn(2, 3)
        ShapeProp(traced).propagate(input)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "call_function" and node.target == torch.add:
                self.assertTrue("my_key" in node.meta)
                self.assertEqual(node.meta["my_key"], 7)
                break
        else:
            self.fail("Didn't find call_function torch.add")

    def test_normalize_args_preserve_type(self):
        class MyModule(torch.nn.Module):
            def forward(self, a: List[torch.Tensor]):
                return torch.add(a[0], a[1])

        m = MyModule()
        traced = symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()

        for node in traced.graph.nodes:
            if node.op == "placeholder":
                self.assertEqual(node.type, List[torch.Tensor])

    @skipIfNoTorchVision
    def test_annotate_returns_with_schema(self):
        m = resnet18()

        traced_modules = symbolic_trace(m)
        traced_modules_annotated = AnnotateTypesWithSchema(traced_modules).transform()
        for node in traced_modules_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                self.assertIn(
                    check,
                    {
                        ("placeholder", "x"),
                        ("call_module", "maxpool"),
                        ("call_function", operator.add),
                        ("call_function", torch.flatten),
                        ("output", "output"),
                    }
                )

        # Smoke test TorchScript compilation, since we now emit type annotations
        torch.jit.script(traced_modules_annotated)

        class FunctionalTracer(torch.fx.Tracer):
            def is_leaf_module(
                self, m: torch.nn.Module, module_qualified_name: str
            ) -> bool:
                # `leaves` contains the set of standard `nn.Modules` that are not
                # currently symbolically traceable. Ideally this set would be empty
                leaves = {torch.nn.BatchNorm2d}
                return type(m) in leaves

        traced_functionals = torch.fx.GraphModule(m, FunctionalTracer().trace(m))

        traced_functionals_annotated = AnnotateTypesWithSchema(
            traced_functionals
        ).transform()
        for node in traced_functionals_annotated.graph.nodes:
            if node.type is None:
                check = (node.op, node.target)
                excluded_nodes = {
                    ("placeholder", "x"),
                    # Return type differs based on boolean dispatch :(
                    ("call_function", torch.nn.functional.max_pool2d),
                    ("output", "output"),
                }
                # AnnotateTypesWithSchema doesn't work with bound C++ functions
                if not isinstance(node.target, BuiltinFunctionType):
                    self.assertIn(check, excluded_nodes)

        # Smoke test TorchScript compilation, since we now emit type annotations
        torch.jit.script(traced_functionals_annotated)

    def test_annotate_getitem_node(self):
        class CustomType:
            pass

        class CustomNamedTuple(NamedTuple):
            x: int
            y: float

        class MyModule(torch.nn.Module):
            def forward(self, inp: Tuple[CustomType, torch.Tensor], inp2: List[CustomType], inp3: CustomNamedTuple):
                inp_0 = inp[0]
                inp_1 = inp[1]
                inp2_0 = inp2[0]
                inp3_x = inp3.x
                inp3_y = inp3.y
                return inp_0 + inp_1 + inp2_0 + inp3_x + inp3_y

        my_module = MyModule()
        my_module_traced = torch.fx.symbolic_trace(my_module)

        # By default, FX tracing loses the type annotations of getitem nodes.
        for node in my_module_traced.graph.nodes:
            if node.target == operator.getitem:
                assert node.type is None

        annotate_getitem_nodes(my_module_traced.graph)

        for node in my_module_traced.graph.nodes:
            if node.target == operator.getitem:
                self.assertIsNotNone(node.type, f"Node {node} should be annotated but is not.")

    def test_subgraph_uniquename(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, a, b, c, d):
                add_1 = a + b
                add_2 = add_1 + c
                linear_1 = self.linear(add_1)
                add_3 = add_2 + d
                add_4 = add_2 + linear_1
                add_5 = add_3 + add_4
                return add_5

        a, b, c, d = torch.ones(4), torch.ones(4), torch.ones(4), torch.ones(4)
        mm = MyModule()
        traced = symbolic_trace(mm)

        def split_cb(node: torch.fx.Node):
            if node.name == "a" or node.name == "b" or node.name == "add":
                return 0
            else:
                return 1

        module_with_submodule = split_module(traced, mm, split_cb)
        self.assertEqual(module_with_submodule(a, b, c, d), traced(a, b, c, d))

    def test_split_qualname_mapping(self):
        d_hid = 4

        class ExampleCode(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mm_param = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.mm_param2 = torch.nn.Parameter(torch.randn(d_hid, d_hid))
                self.lin = torch.nn.Linear(d_hid, d_hid)

            def forward(self, x):
                x = torch.mm(x, self.mm_param)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param)
                x = self.lin(x)
                x = torch.relu(x)
                x = torch.mm(x, self.mm_param2)
                x = self.lin(x)
                return x

        my_module = ExampleCode()
        my_module_traced = symbolic_trace(my_module)

        part_idx = 0

        def split_callback(n : torch.fx.Node):
            nonlocal part_idx
            if (n.op, n.target) == ('call_module', 'lin'):
                part_idx += 1
            return part_idx

        # split the module into a module with submodules
        qualname_map : Dict[str, str] = {}
        module_with_submodules = split_module(
            my_module_traced, my_module, split_callback, qualname_map
        )
        expected_qualname_map = {
            'submod_1.lin': 'lin', 'submod_2.lin': 'lin'
        }
        self.assertEqual(qualname_map, expected_qualname_map)

    def test_traceable_function_with_nonstandard_name(self):
        def foo(x):
            return torch.relu(x)

        traced = symbolic_trace_with_rewrite(foo)

    def test_to_folder(self):
        class Test(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.W = torch.nn.Parameter(torch.randn(2))
                self.seq = torch.nn.Sequential(torch.nn.BatchNorm1d(2, 2))
                self.linear = torch.nn.Linear(2, 2)
                self.attr = torch.randn(2)
                self.attr2 = torch.nn.Buffer(torch.randn(2))
                self.attr3 = torch.nn.Buffer(torch.ones(2, dtype=torch.int32))

            def forward(self, x):
                return self.linear(self.seq(self.W + self.attr + self.attr2 + self.attr3 + x))

        mod = symbolic_trace(Test())
        module_name = "Foo"
        from pathlib import Path

        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_dir = Path(tmp_dir)
            mod.to_folder(tmp_dir, module_name)
            # Recipe taken from here:
            # https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly
            import importlib.util

            spec = importlib.util.spec_from_file_location(
                module_name, tmp_dir / "__init__.py"
            )
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)
            t = torch.randn(2, 2)
            self.assertEqual(module.Foo()(t), mod(t))

    def test_fetch(self):
        attrs_for_lowering: Dict[str, List[str]] = {
            "torch.nn.modules.conv.Conv2d": [
                "weight",
                "bias",
                "kernel_size",
                "stride",
                "padding",
                "dilation",
                "groups",
                "padding_mode",
            ],
            "torch.nn.modules.batchnorm.BatchNorm2d": [
                "weight",
                "bias",
                "running_mean",
                "running_var",
                "eps",
            ],
        }

        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 2)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, a):
                a = self.conv(a)
                a += a
                return self.bn(a)

        mod = TestModule()
        traced = symbolic_trace(mod)
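        # lift_lowering_attrs_to_nodes attaches each leaf module's
        # lowering-relevant attributes to its call_module node as the
        # attrs_for_lowering dict checked below.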
        lift_lowering_attrs_to_nodes(traced)

        for node in traced.graph.nodes:
            if node.op == "call_module":
                assert hasattr(node, "attrs_for_lowering")
                para_list = attrs_for_lowering[node.attrs_for_lowering["name"]]

                # node.attrs_for_lowering has an additional field for the class name
                assert len(para_list) + 1 == len(node.attrs_for_lowering)
                for p_name in para_list:
                    assert p_name in node.attrs_for_lowering

    def test_merge_matmuls(self):
        """
        A collection of test cases for torch.fx.experimental.merge_matmul,
        a graph transformation that merges matrix multiplication operations.
        """
        # Utility function for counting matmuls for test assertions.
        def _count_matmuls(mod):
            gm = torch.fx.symbolic_trace(mod)

            num_matmuls = 0
            for node in gm.graph.nodes:
                if node.target == torch.matmul:
                    num_matmuls += 1

            return num_matmuls

        # Simple test case in which there are two matmuls of the same size to merge.
        class SimpleMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x, y):
                a = torch.matmul(x, self.rhs)
                b = torch.matmul(y, self.rhs)
                return a + b

        # Initialize inputs.
        a = torch.randn(3, 3)
        b = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct SimpleMergeMatmulModule and call merge_matmul on it.
        module = SimpleMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a, b)
        after = opt_module(a, b)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; original module should have 2 matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Test case in which there are multiple matmuls of different sizes to merge.
        class FiveMergeMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, a, b, c, d, e):
                s = torch.tensor([])
                matmuls = []

                # For some reason, building these with a list comprehension or
                # for-loop doesn't work; append the matmuls manually instead.
                matmuls.append(torch.matmul(a, self.rhs))
                matmuls.append(torch.matmul(b, self.rhs))
                matmuls.append(torch.matmul(c, self.rhs))
                matmuls.append(torch.matmul(d, self.rhs))
                matmuls.append(torch.matmul(e, self.rhs))

                for m in matmuls:
                    s += torch.sum(m)

                return s

        # Initialize inputs.
        inputs = [torch.randn(2 * i + 1, 5) for i in range(5)]

        # Initialize RHS.
        rhs = torch.randn(5, 4)

        # Construct FiveMergeMatmulModule and call merge_matmul on it.
        module = FiveMergeMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(*inputs)
        after = opt_module(*inputs)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; original module should have len(inputs) matmuls
        # and optimized module should have 1.
        self.assertEqual(_count_matmuls(module), len(inputs))
        self.assertEqual(_count_matmuls(opt_module), 1)

        # Simple test case in which two matmuls cannot be merged due to a data dependency between
        # the LHS operands.
        class UnmergeableMatmulModule(torch.nn.Module):
            def __init__(self, rhs):
                super().__init__()
                self.rhs = rhs

            def forward(self, x):
                a = torch.matmul(x, self.rhs)
                a_abs = torch.abs(a)
                b = torch.matmul(a_abs.transpose(1, 0), self.rhs)
                return b

        # Initialize inputs.
        a = torch.randn(3, 3)

        # Initialize RHS for matmuls.
        rhs = torch.randn(3, 4)

        # Construct UnmergeableMatmulModule and call merge_matmul on it.
        module = UnmergeableMatmulModule(rhs)
        opt_module = merge_matmul.merge_matmul(module)

        # Numerical correctness check.
        before = module(a)
        after = opt_module(a)
        self.assertTrue(before.allclose(after))

        # Basic graph structure check; the number of matrix multiplications should not have changed.
        self.assertEqual(_count_matmuls(module), 2)
        self.assertEqual(_count_matmuls(opt_module), 2)

    def test_type_matches(self):
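        # type_matches(signature_type, argument_type) is FX's compatibility
        # check for operator normalization; create_type_hint builds a typing
        # hint from a concrete list or tuple of example values.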
        should_be_equal = [
            (int, int),
            (numbers.Number, int),
            (numbers.Number, float),
            (int, type(torch.float)),
            (Union[int, float], int),
            (Union[int, float], float),
            (List[int], int),
            (List[int], create_type_hint([int, int])),
            (List[int], create_type_hint((int, int))),
            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.Tensor])),
            (
                List[torch.Tensor],
                create_type_hint([torch.nn.Parameter, torch.nn.Parameter]),
            ),
            (torch.Tensor, torch.nn.Parameter),
            (List[torch.Tensor], create_type_hint([torch.nn.Parameter, torch.Tensor])),
            (List[torch.Tensor], create_type_hint([torch.Tensor, torch.nn.Parameter])),
            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.Tensor))),
            (
                List[torch.Tensor],
                create_type_hint((torch.nn.Parameter, torch.nn.Parameter)),
            ),
            (List[torch.Tensor], create_type_hint((torch.nn.Parameter, torch.Tensor))),
            (List[torch.Tensor], create_type_hint((torch.Tensor, torch.nn.Parameter))),
            (Optional[List[torch.Tensor]], List[torch.Tensor]),
            (Optional[List[int]], List[int]),
        ]
        for sig_type, arg_type in should_be_equal:
            self.assertTrue(type_matches(sig_type, arg_type))

        should_fail = [
            (int, float),
            (Union[int, float], str),
            (List[torch.Tensor], List[int]),
        ]

        for sig_type, arg_type in should_fail:
            self.assertFalse(type_matches(sig_type, arg_type))

    @skipIfNoMkldnn
    def test_optimize_for_inference_cpu(self):
        import torch.nn as nn

        class Foo(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                layers = []
                layers2 = []
                for _ in range(10):
                    layers.append(nn.Conv2d(3, 3, 1))
                    layers.append(nn.BatchNorm2d(3))
                    layers.append(nn.ReLU())

                    layers2.append(nn.Conv2d(3, 3, 1))
                    layers2.append(nn.BatchNorm2d(3))
                    layers2.append(nn.ReLU())
                self.model = nn.Sequential(*layers)
                self.model2 = nn.Sequential(*layers2)

            def forward(self, x):
                return self.model(x) + self.model2(x)

        N, C, H, W = 1, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        with torch.no_grad():
            model = Foo().eval()
            optimized_model = optimization.optimize_for_inference(model)
            torch.testing.assert_close(model(inp), optimized_model(inp))

            optimized_model2 = optimization.optimize_for_inference(
                model, pass_config={"remove_dropout": False}
            )
            torch.testing.assert_close(model(inp), optimized_model2(inp))

    @skipIfNoTorchVision
    @skipIfNoMkldnn
    def test_optimize_for_inference_cpu_torchvision(self):
        models = [
            torchvision.models.resnet18,
            torchvision.models.resnet50,
            torchvision.models.densenet121,
            torchvision.models.shufflenet_v2_x1_0,
            torchvision.models.vgg16,
            torchvision.models.mobilenet_v2,
            torchvision.models.mnasnet1_0,
            torchvision.models.resnext50_32x4d,
        ]
        with torch.no_grad():
            for model_type in models:
                model = model_type()
                C, H, W = (3, 224, 224)
                # One forward pass in training mode populates the batch-norm
                # running statistics before switching to eval.
                inp = torch.randn(3, C, H, W)
                model(inp)
                model.eval()
                inp = torch.randn(1, C, H, W)
                # gen_mkl_autotuner is exercised for coverage; its result is
                # unused below.
                heuristic = optimization.gen_mkl_autotuner(inp, iters=0, warmup=0)
                optimized_model = optimization.optimize_for_inference(model)

                orig_out = model(inp)
                new_out = optimized_model(inp)
                torch.testing.assert_close(orig_out, new_out)


class TestNormalizeOperators(JitTestCase):
    @onlyCPU
    @ops(op_db, allowed_dtypes=(torch.float,))
    def test_normalize_operator_exhaustive(self, device, dtype, op):
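        # Strategy: for each OpInfo sample, infer argument types, run the op
        # through normalize_function, and check that the normalized call
        # reproduces the reference output, both directly and via an FX-traced
        # module.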
        # These ops currently don't trace in FX for various reasons (e.g., they take a list of tensors).
        fx_fail = {
            "cat",
            "stack",
            "hstack",
            "vstack",
            "dstack",
            "linalg.multi_dot",
            "_upsample_bilinear2d_aa",
            "_chunk_cat",
        }
        sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
        if isinstance(op.op, torch._ops.OpOverload):
            self.skipTest("normalize operator doesn't work on torch.ops")
        for sample_input in sample_inputs_itr:
            unsupported_arg_type = False
            arg_values = [sample_input.input] + list(sample_input.args)
            kwarg_values = sample_input.kwargs
            arg_types = []
            kwarg_types = {}

            def jit_infer_type(v):
                # Recover a Python type hint for non-tensor values via
                # TorchScript's type inference.
                inferred_arg_type = torch._C._jit_try_infer_type(v)
                assert inferred_arg_type.success()
                return _torchscript_type_to_python_type(inferred_arg_type.type())

            for v in arg_values:
                if isinstance(v, torch.Tensor):
                    arg_types.append(type(v))
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    arg_types.append(jit_infer_type(v))

            for k, v in kwarg_values.items():
                if isinstance(v, torch.Tensor):
                    kwarg_types[k] = type(v)
                else:
                    if isinstance(v, complex):
                        # Complex type not supported in FX
                        unsupported_arg_type = True
                    kwarg_types[k] = jit_infer_type(v)

            if unsupported_arg_type:
                continue
            # Test normalize_function by itself
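            # normalize_function matches args/kwargs against the op's schema
            # and returns an equivalent ArgsKwargsPair, or None if it cannot.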
            ref_out = op.op(*arg_values, **kwarg_values)
            norm_args_and_kwargs = normalize_function(
                op.op, arg_values, kwarg_values, arg_types, kwarg_types
            )
            if norm_args_and_kwargs is None:
                raise RuntimeError(
                    """
                    FX failed to normalize op - add the op to the op_skip list.
                    A common reason is if your OpInfo was implemented with a lambda
                    - otherwise, file an issue
                    """
                )
            test_out = op.op(*norm_args_and_kwargs.args, **norm_args_and_kwargs.kwargs)
            self.assertEqual(test_out, ref_out)

            # Test normalized_arguments as part of FX
            if op.name in fx_fail:
                continue
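            # Build source for a module whose forward() takes each tensor
            # argument as a parameter and inlines non-tensor arguments as
            # literals, then trace it and normalize the arguments of every
            # call_function node in the traced graph.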
            param_names = []
            param_values = []
            fx_args = []

            idx = 0

            def process_arg(arg, name):
                if isinstance(arg, torch.Tensor):
                    param_names.append(name)
                    param_values.append(arg)
                    return name
                else:
                    return f"{repr(arg)}"

            def process_arg_with_idx(arg):
                nonlocal idx
                res = process_arg(arg, f"arg_{idx}")
                idx = idx + 1
                return res

            def str_arg(arg):
                if isinstance(arg, tuple):
                    # Keep the trailing comma so single-element tuples stay
                    # tuples in the generated source.
                    args = [f"{str_arg(v)}, " for v in arg]
                    return f"({' '.join(args)})"
                elif isinstance(arg, list):
                    args = [f"{str_arg(v)}" for v in arg]
                    return f"[{', '.join(args)}]"
                else:
                    return arg

            for v in arg_values:
                arg = pytree.tree_map(process_arg_with_idx, v)
                fx_args.append(str_arg(arg))

            for k, v in kwarg_values.items():
                arg = pytree.tree_map(functools.partial(process_arg, name=k), v)
                fx_args.append(f"{k} = {str_arg(arg)}")

            code = f"""
class TestModule(torch.nn.Module):
    def forward(self, {', '.join(param_names)}):
        return torch.{op.name}({', '.join(fx_args)})
            """

            g = {"torch": torch, "inf": math.inf}
            exec(code, g)
            TestModule = g["TestModule"]

            m = TestModule()
            traced = torch.fx.symbolic_trace(m)
            ref_out = traced(*param_values)

            for node in traced.graph.nodes:
                if node.op == "call_function":
                    normalized_args = node.normalized_arguments(
                        traced, arg_types, kwarg_types
                    )
                    assert normalized_args
                    node.args = normalized_args.args
                    node.kwargs = normalized_args.kwargs
            traced.recompile()

            test_out = traced(*param_values)
            self.assertEqual(test_out, ref_out)

    def test_normalize_quantized_eb(self):
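        # With normalize_to_only_use_kwargs=True, every argument should end
        # up in kwargs under its schema parameter name, leaving args empty.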
        target = torch.ops.quantized.embedding_bag_byte_rowwise_offsets
        args = (
            torch.empty((2, 3), dtype=torch.uint8),
            torch.empty((2,), dtype=torch.int64),
            torch.empty((2,), dtype=torch.int64),
        )
        norm_args_and_kwargs = normalize_function(
            target, args, normalize_to_only_use_kwargs=True
        )
        self.assertIsNotNone(norm_args_and_kwargs)
        self.assertEqual(
            set(norm_args_and_kwargs.kwargs.keys()),
            {
                "weight",
                "indices",
                "offsets",
                "scale_grad_by_freq",
                "mode",
                "pruned_weights",
                "per_sample_weights",
                "compressed_indices_mapping",
                "include_last_offset",
            },
        )
        self.assertEqual(norm_args_and_kwargs.args, ())

    def test_normalize_args_op_overload(self):
        for target in [torch.ops.aten.resize_as_.default, torch.ops.aten.resize_as_]:
            inp1 = torch.rand([1])
            inp2 = torch.rand([4])
            args, kwargs = normalize_function(
                target,
                (inp1,),
                {"the_template": inp2},
                normalize_to_only_use_kwargs=True,
            )
            self.assertIs(kwargs["input"], inp1)
            self.assertIs(kwargs["the_template"], inp2)


if TEST_Z3:
    import z3

    import torch._dynamo.config

    from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str
    from torch.utils._sympy.functions import FloorDiv, Mod

    class TestTranslationValidation(TestCase):
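        # Translation validation checks an implication between constraint
        # sets: validate() succeeds iff every assignment satisfying the
        # target expressions also satisfies the source expressions.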
        def _prepare_for_translation_validation(self):
            validator = TranslationValidator()

            # SymPy symbols.
            s0, s1, s2 = sympy.symbols("s0 s1 s2", integer=True)

            # Z3 symbols.
            for s in (s0, s1, s2):
                validator.add_var(s, int)
            z0, z1, z2 = (validator.z3var(s) for s in (s0, s1, s2))

            return (s0, s1, s2), (z0, z1, z2), validator

        def test_sympy_to_z3(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            test_cases = [
                # Integer constants.
                (sympy.S.Zero, z3.IntVal(0)),
                (sympy.S.One, z3.IntVal(1)),
                (sympy.S.NegativeOne, z3.IntVal(-1)),
                (sympy.Integer(2), z3.IntVal(2)),
                (
                    s0,
                    z0,
                ),
                # Arithmetic operations.
                *[
                    (op(s0, s1), op(z0, z1))
                    for op in (
                        operator.add,
                        operator.mul,
                        operator.pow,
                    )
                ],
                # Logical operations.
                *[
                    (sympy_op(s0, s1), z3_op(z0, z1))
                    for sympy_op, z3_op in (
                        (sympy.Eq, operator.eq),
                        (sympy.Ne, operator.ne),
                        (sympy.Lt, operator.lt),
                        (sympy.Le, operator.le),
                        (sympy.Gt, operator.gt),
                        (sympy.Ge, operator.ge),
                    )
                ],
                # Other operations.
                (
                    s0 - s1,
                    z0 + z3.IntVal(-1) * z1,
                ),
                (
                    s0 / s1,
                    z3.ToReal(z0) * (z1**-1),
                ),
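                # Z3's ToInt floors its argument, so FloorDiv becomes real
                # division rounded toward negative infinity, and Mod follows
                # from the identity a % b == a - floor(a / b) * b.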
                (FloorDiv(s0, s1), z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1))),
                (Mod(s0, s1), z0 - z3.ToInt(z3.ToReal(z0) / z3.ToReal(z1)) * z1),
                (
                    Mod(s2, (s0 / s1)),
                    z2
                    - z3.ToReal(z3.ToInt(z3.ToReal(z2) / (z3.ToReal(z0) * z1**-1)))
                    * (z3.ToReal(z0) * z1**-1),
                ),
                (
                    Mod(s2, s0**3),
                    z2 - z3.ToReal(z3.ToInt(z3.ToReal(z2) / z0**3)) * z0**3,
                ),
            ]

            toZ3 = SympyToZ3(validator)
            for sympy_expr, z3_expr in test_cases:
                result = toZ3.run(sympy_expr)
                self.assertTrue(
                    z3_expr.eq(result), msg=f"expected: {z3_expr}. Got: {result}"
                )

        def test_sat(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # The solutions of the target expressions are a subset of the
            # solutions of the source expressions, so validation succeeds.
            validator.add_target_expr(s0 > 20)
            validator.add_target_expr(s1 > s0**2)

            validator.validate()

        def test_unsat(self):
            (
                (s0, s1, s2),
                (z0, z1, z2),
                validator,
            ) = self._prepare_for_translation_validation()

            validator.add_source_expr(z0 > 5)
            validator.add_source_expr(z1 / 2 > z0)

            # The solutions of the target expressions are NOT a subset of the
            # solutions of the source expressions.
            validator.add_target_expr(s0 > 20)
            # This expression is less restrictive than its source counterpart
            # (z1 / 2 > z0), so some target solutions violate the source.
            validator.add_target_expr(s1 > s0 + 2)

            with self.assertRaisesRegex(ValidationException, "translation validation failed."):
                validator.validate()

        def test_z3str(self):
            a = z3.Int("a")
            b = z3.Int("b")
            special = z3.Real("this.size()[2]")

            # z3str renders Z3 expressions as Lisp-style s-expressions,
            # normalizing a few constructs as exercised below.
            test_cases = [
                (z3.IntVal(42), "42"),
                # Variable.
                (a, "a"),
                # Name with special characters.
                (special, "this.size()[2]"),
                # Renamed function applications.
                (a != b, "(!= a b)"),
                (a ** b, "(pow a b)"),
                # Chain of associative operations.
                *[
                    (op(op(a, 5), b), f"({opstr} 5 a b)")
                    for op, opstr in [
                        (operator.add, "+"),
                        (operator.mul, "*")
                    ]
                ],
                # Revert 'Not' conversions.
                (a != b, "(!= a b)"),
                (a < b, "(> b a)"),
                (a > b, "(> a b)"),
                # Ignore 'ToInt' and 'ToReal' functions.
                (z3.ToInt(special) + a, "(+ this.size()[2] a)"),
                (z3.ToReal(a + b), "(+ a b)"),
                # Convert to floor division: 'idiv'.
                (z3.ToInt(z3.ToReal(a) / z3.ToReal(b)), "(idiv a b)"),
            ]

            for expr, expected in test_cases:
                self.assertEqual(z3str(expr), expected)


instantiate_device_type_tests(TestNormalizeOperators, globals())

if __name__ == "__main__":
    run_tests()
