# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
import copy
import os
import tempfile
import unittest
from typing import List, Optional, Tuple

import executorch.exir as exir

# Import passes
import executorch.exir.memory_planning  # noqa
import torch
from executorch.exir import EdgeCompileConfig, EdgeProgramManager, memory, to_edge
from executorch.exir.dialects._ops import bind_pattern_to_op, ops, ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.emit import emit_program
from executorch.exir.graph_module import get_control_flow_submodules
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import (
    dead_code_elimination_pass,
    DebugPass,
    HintBasedSymShapeEvalPass,
    MemoryPlanningPass,
    propagate_dynamic_shape,
    RemoveNoopPass,
    ReplaceSymSizeOpPass,
    ToOutVarPass,
)
from executorch.exir.passes.constant_prop_pass import constant_prop_pass
from executorch.exir.passes.debug_handle_generator_pass import (
    DebugHandleGeneratorPass,
    generate_missing_debug_handles,
)
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
    insert_write_back_for_buffers_pass,
)

from executorch.exir.passes.memory_format_ops_pass import DimOrderOpsRevertPass
from executorch.exir.passes.normalize_view_copy_base_pass import (
    NormalizeViewCopyBasePass,
)
from executorch.exir.passes.remove_graph_asserts_pass import RemoveGraphAssertsPass
from executorch.exir.passes.remove_mixed_type_operators import RemoveMixedTypeOperators
from executorch.exir.passes.replace_edge_with_backend_pass import EdgeToBackendOpsPass
from executorch.exir.passes.replace_view_copy_with_view_pass import (
    ReplaceViewCopyWithViewPass,
)
from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from executorch.exir.passes.sym_to_tensor_pass import SymToTensorPass
from executorch.exir.program._program import lift_constant_tensor_pass
from executorch.exir.schema import TensorShapeDynamism
from executorch.exir.tensor import TensorSpec
from executorch.exir.tests.common import register_additional_test_aten_ops
from executorch.exir.tests.control_flow_models import FTCondDeadCode, FTMapBasic
from executorch.exir.tests.models import MLP, Mul
from functorch.experimental import control_flow

from torch import nn

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)
from torch.export import export
from torch.export.graph_signature import InputKind, InputSpec, TensorArgument
from torch.fx import GraphModule, subgraph_rewriter
from torch.fx.experimental.proxy_tensor import make_fx
from torch.library import impl, Library
from torch.testing import FileCheck
from torch.utils import _pytree as pytree


# pyre-ignore
def collect_ops(gm: torch.fx.GraphModule):
    """
    Collect all targets for call_function nodes from the graph module recursively.
    """
    ops = set()
    for subgm in gm.modules():
        if not isinstance(subgm, torch.fx.GraphModule):
            continue
        for node in subgm.graph.nodes:
            if node.op == "call_function":
                ops.add(node.target)
    return ops


lib = Library("DO_NOT_USE_TEST_ONLY", "DEF")
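# This test-only library registers custom ops used below: foo/foo.out are
# exercised by test_to_out_variant_none_output, and add_relu backs
# test_backend_fused_op_retraceable.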

lib.define("foo(Tensor self) -> (Tensor, Tensor)")
lib.define("add_relu(Tensor self, Tensor other) -> Tensor")


@impl(lib, "foo", "CompositeExplicitAutograd")
def foo(a: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    return a + 1, None


lib.define(
    "foo.out(Tensor self, *, Tensor(a!) out1, Tensor(b!) out2) -> (Tensor(a!), Tensor(b!))"
)


@impl(lib, "foo.out", "CompositeExplicitAutograd")
def foo_out(
    a: torch.Tensor, out1: torch.Tensor, out2: torch.Tensor
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    return a + 1, None


class TestPasses(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        register_additional_test_aten_ops()

    def test_remove_mixed_type_operators(self) -> None:
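        # RemoveMixedTypeOperators should promote mixed-dtype add/mul so that
        # every operand carries the promoted dtype (float, double, then float
        # for the three cases below); the loops verify each operand's dtype.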
        class Add(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return (x + y) + x

        add = Add()

        int_tensor = torch.tensor([[1, 2, 3]])
        float_tensor = torch.tensor([[1.0, 2.0, 3.0]])
        edge_prog = to_edge(
            export(
                add,
                (int_tensor, float_tensor),
            )
        )

        new_prog = edge_prog.transform([RemoveMixedTypeOperators()])
        new_graph_module = new_prog.exported_program().graph_module
        self.assertIsNotNone(new_graph_module)

        add_count = 0

        for node in new_graph_module.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.add.Tensor
            ):
                add_count += 1
                node_args = node.args
                for arg in node_args:
                    self.assertEqual(arg.meta["val"].dtype, torch.float)

        self.assertEqual(add_count, 2)

        double_tensor = torch.tensor([[1.0, 2.0, 3.0]])
        double_tensor = double_tensor.to(torch.double)

        double_prog = to_edge(export(add, (int_tensor, double_tensor)))

        double_prog = double_prog.transform([RemoveMixedTypeOperators()])
        new_graph_module_double = double_prog.exported_program().graph_module
        self.assertIsNotNone(new_graph_module_double)

        add_count_double = 0

        for node in new_graph_module_double.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.add.Tensor
            ):
                add_count_double += 1
                node_args = node.args
                for arg in node_args:
                    self.assertEqual(arg.meta["val"].dtype, torch.double)

        self.assertEqual(add_count_double, 2)

        class Mult(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x * y

        mult = Mult()

        float_tensor_vert = float_tensor.T
        mult_prog = to_edge(
            export(
                mult,
                (int_tensor, float_tensor_vert),
            )
        )

        mult_prog = mult_prog.transform([RemoveMixedTypeOperators()])
        new_graph_module_mult = mult_prog.exported_program().graph_module
        self.assertIsNotNone(new_graph_module_mult)

        mult_count = 0

        for node in new_graph_module_mult.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == exir_ops.edge.aten.mul.Tensor
            ):
                mult_count += 1
                node_args = node.args
                for arg in node_args:
                    self.assertEqual(arg.meta["val"].dtype, torch.float)

        self.assertEqual(mult_count, 1)

    def test_remove_noop_pass(self) -> None:
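        # A .to() call that keeps the same dtype is a no-op; RemoveNoopPass
        # should eliminate it.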
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.to(dtype=torch.float32)

        foo = Foo()

        # Turn off functionalization so that we can get the actual to.dtype op
        edge_prog = to_edge(
            export(
                foo,
                (torch.ones(1, dtype=torch.float32),),
            )
        )
        edge_prog = edge_prog.transform([RemoveNoopPass()])
        self.assertIsNotNone(edge_prog.exported_program().graph_module)
        new_graph_module = edge_prog.exported_program().graph_module
        for node in new_graph_module.graph.nodes:
            if node.op == "call_function":
                self.assertNotEqual(node.target, torch.ops.aten.to.dtype)

    def test_redundant_slice_copy_removal(self) -> None:
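        # Slices spanning an entire dimension are no-ops. RemoveNoopPass should
        # keep only the slice_copy ops that actually narrow a dimension:
        # 0, 1 and 3 respectively for the three modules below.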
        class FooWithNoSlice(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x[:, :, :]

        foo_with_no_slice = FooWithNoSlice()

        class FooWithOneSlice(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x[:1, :, :]

        foo_with_one_slice = FooWithOneSlice()

        class FooWithAllSlices(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x[:1, :2, 2:4]

        foo_with_all_slices = FooWithAllSlices()

        x = torch.ones((3, 8, 8))
        prog = to_edge(
            export(
                foo_with_no_slice,
                (x,),
            )
        )
        prog = prog.transform([RemoveNoopPass()])
        new_graph_module = prog.exported_program().graph_module
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 0, exactly=True
        ).run(new_graph_module.code)

        prog = to_edge(
            export(
                foo_with_one_slice,
                (x,),
            )
        )
        prog = prog.transform([RemoveNoopPass()])
        new_graph_module = prog.exported_program().graph_module
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 1, exactly=True
        ).run(new_graph_module.code)

        prog = to_edge(
            export(
                foo_with_all_slices,
                (x,),
            )
        )
        prog = prog.transform([RemoveNoopPass()])
        new_graph_module = prog.exported_program().graph_module
        FileCheck().check_count(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor", 3, exactly=True
        ).run(new_graph_module.code)

    def test_compile_to_edge(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x * 2

        f = Foo()

        x = (torch.randn(2, 3),)

        to_edge(
            export(
                f,
                x,
            )
        ).exported_program().graph_module
        # TODO(angelayi): Add a utility function that verifies a model is in
        # the edge dialect

    def test_to_out_variant_none_output(self) -> None:
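        # After ToOutVarPass, an out-variant node with an optional output should
        # get a memory.alloc for the present output and None for the absent one;
        # the program should then still memory-plan and emit successfully.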
        class CompositeModel(torch.nn.Module):
            def __init__(self, _weight):
                super().__init__()
                self.weight = _weight
                self.lstm = torch.nn.LSTM(
                    input_size=32,
                    hidden_size=32,
                    num_layers=1,
                )

            def forward(self, x_raw, h, c):
                output, (hn, cn) = self.lstm(x_raw, (h, c))
                return output

        # Prepare input and trace it
        input_x = torch.ones([1, 32])
        input_h = torch.ones([1, 32])
        input_c = torch.ones([1, 32])
        inputs = (input_x, input_h, input_c)

        composite_m = CompositeModel(3)

        edge_prog = to_edge(
            export(
                composite_m,
                inputs,
            ),
            # The traced graph contains torch._ops.aten.t.default, hence
            # _check_ir_validity=False.
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )

        new_prog = edge_prog.transform([SpecPropPass()])

        new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module
        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target in [
                torch.ops.DO_NOT_USE_TEST_ONLY.foo.out,
                torch.ops.my_awesome_3rdparty_ns.awesome_op.out,
            ]:
                self.assertEqual(len(node.kwargs), 2)
                out1_node = node.kwargs["out1"]
                self.assertEqual(out1_node.op, "call_function")
                self.assertIs(out1_node.target, memory.alloc)
                self.assertIs(node.kwargs["out2"], None)

        new_gm_res = MemoryPlanningPass()(new_gm)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module
        new_prog.exported_program().graph_module.graph = new_gm.graph
        emit_program(new_prog.exported_program())

    def test_to_out_variant_singleton_tensor_list(self) -> None:
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.split(x, 10)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        prog = to_edge(
            export(
                model,
                inputs,
            ),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )  # TODO(larryliu): fix split_copy
        new_gm_res = ToOutVarPass()(prog.exported_program().graph_module)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module

        for nd in new_gm.graph.nodes:
            if nd.target is exir_ops.edge.aten.split_copy.Tensor_out:
                break

        val = nd.meta["val"]

        # We must return a spec that is a list containing a single TensorSpec
        # item. Returning the TensorSpec item directly causes future getitem
        # ops to fail.
        self.assertTrue(isinstance(val, (tuple, list)))
        self.assertEqual(1, len(val))

    def test_to_out_variant_multiple_out(self) -> None:
        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return torch.topk(x, 5)

            def get_random_inputs(self):
                return (torch.randn(10),)

        model = MyModel()
        inputs = model.get_random_inputs()
        prog = to_edge(
            export(
                model,
                inputs,
            ),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )  # TODO(larryliu): fix topk
        new_gm_res = ToOutVarPass()(prog.exported_program().graph_module)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module

        for nd in new_gm.graph.nodes:
            if nd.target is torch.ops.aten.topk.values:
                break

        val = nd.meta["val"]

        # We must return a spec that is a list containing a single TensorSpec
        # item. Returning the TensorSpec item directly causes future getitem
        # ops to fail.
        self.assertTrue(isinstance(val, (tuple, list)))
        self.assertEqual(2, len(val))

    def test_to_out_variant_to_copy(self) -> None:
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                return x.to(torch.int32)

        model = Module()

        inputs = torch.tensor(1.0, dtype=torch.float)
        model_res = model(inputs)

        edge_dialect = to_edge(
            export(
                model,
                (inputs,),
            )
        )
        edge_res = edge_dialect.exported_program().module()(inputs)
        self.assertTrue(torch.allclose(model_res, edge_res))

    def test_export_pass(self) -> None:
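        # Running a no-op ExportPass should preserve node count, targets, and
        # stack_trace metadata.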
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
                y = torch.cat([x, x])
                return torch.ops.aten.tensor_split.sections(y, 2)

        f = Foo()

        class NullPass(ExportPass):
            pass

        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
            ),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )  # TODO(larryliu): fix cat
        new_prog = prog.transform([NullPass()])
        new_nodes = new_prog.exported_program().graph_module.graph.nodes
        for node in new_nodes:
            if node.op != "call_function":
                continue
            self.assertTrue(hasattr(node, "stack_trace"))
            self.assertIsNotNone(node.stack_trace)

        old_nodes = prog.exported_program().graph_module.graph.nodes
        self.assertEqual(len(new_nodes), len(old_nodes))
        for new_node, old_node in zip(new_nodes, old_nodes):
            self.assertEqual(new_node.op, old_node.op)
            self.assertEqual(new_node.target, old_node.target)

    def test_export_pass_pt2(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
                y = torch.cat([x, x])
                return torch.ops.aten.tensor_split.sections(y, 2)

        f = Foo()

        class NullPass(ExportPass):
            pass

        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        new_prog = prog.transform([NullPass()])
        new_nodes = new_prog.exported_program().graph_module.graph.nodes
        for node in new_nodes:
            if node.op != "call_function":
                continue
            self.assertTrue(hasattr(node, "stack_trace"))
            self.assertIsNotNone(node.stack_trace)

        old_nodes = prog.exported_program().graph_module.graph.nodes
        self.assertEqual(len(new_nodes), len(old_nodes))
        for new_node, old_node in zip(new_nodes, old_nodes):
            self.assertEqual(new_node.op, old_node.op)
            self.assertEqual(new_node.target, old_node.target)

    def test_export_scalar_to_tensor_pass(self) -> None:
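        # ScalarToTensorPass replaces scalar operands (the 3.14 here) with
        # tensor constants; results must match and no float args should remain.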
        class Mul(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x * 3.14

        mul = Mul()

        expo_prog = to_edge(export(mul, (torch.ones(1),)))
        new_prog = expo_prog.transform([ScalarToTensorPass()])
        self.assertIsNotNone(new_prog.exported_program().graph_module)
        new_graph_module = new_prog.exported_program().graph_module

        inp = torch.zeros(1)
        self.assertTrue(
            torch.allclose(
                expo_prog.exported_program().module()(inp),
                new_prog.exported_program().module()(inp),
            )
        )
        for node in new_graph_module.graph.nodes:
            if node.op == "call_function":
                for arg in node.args + tuple(node.kwargs.values()):
                    self.assertFalse(isinstance(arg, float))

    def test_remove_mixed_types_symfloats(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.nn.functional.interpolate(
                    x,
                    size=(x.shape[2] * 2, x.shape[3] * 3),
                    mode="bilinear",
                    align_corners=False,
                    antialias=False,
                )

        f = Foo()

        example_inputs = (torch.randn(2, 3, 4, 5),)

        gm = to_edge(
            export(
                f,
                example_inputs,
            )
        )
        new_gm = gm.transform(
            [ReplaceSymSizeOpPass(), ScalarToTensorPass(), RemoveMixedTypeOperators()]
        )
        self.assertIsNotNone(new_gm.exported_program().graph_module)

        self.assertTrue(
            torch.allclose(
                gm.exported_program().module()(*example_inputs),
                new_gm.exported_program().module()(*example_inputs),
            )
        )

    def test_spec_prop_pass(self) -> None:
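        # SpecPropPass populates node.meta["spec"]; the output node's spec
        # should be the very TensorSpec of the value it returns.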
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + x

        f = Foo()

        gm = (
            to_edge(
                export(
                    f,
                    (torch.ones(3, 2),),
                )
            )
            .exported_program()
            .graph_module
        )
        new_gm = SpecPropPass()(gm)
        self.assertIsNotNone(new_gm)
        new_nodes = new_gm.graph_module.graph.nodes
        counter = 0
        for node in new_nodes:
            if node.op != "output":
                continue
            counter += 1
            self.assertIs(node.meta["spec"][0], node.args[0][0].meta["spec"])

        self.assertEqual(counter, 1)

    def test_spec_prop_pass_tuple_output(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor]:
                return (x + x,)

        f = Foo()

        gm = (
            to_edge(
                export(
                    f,
                    (torch.ones(3, 2),),
                )
            )
            .exported_program()
            .graph_module
        )
        new_gm = SpecPropPass()(gm)
        self.assertIsNotNone(new_gm)
        new_nodes = new_gm.graph_module.graph.nodes
        counter = 0
        for node in new_nodes:
            if node.op != "output":
                continue
            counter += 1
            self.assertIs(node.meta["spec"][0], node.args[0][0].meta["spec"])

        self.assertEqual(counter, 1)

    def test_compile_fix_broken_ops(self) -> None:
        # When an input with more than 2 dimensions is passed to Linear,
        # aten._unsafe_view is used under the hood.
        x = torch.randn([2, 3, 4, 5])
        model: torch.nn.Linear = torch.nn.Linear(5, 5)

        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.model = model

            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                return self.model(inp)

        f = Foo()

        # ReplaceBrokenOpsWithFunctionalOpsPass is used in to_edge()
        prog = to_edge(
            export(
                f,
                (x,),
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        gm = prog.exported_program().graph_module
        count_after = 0
        for node in gm.graph.nodes:
            if node.target == torch.ops.aten._unsafe_view.default:
                count_after += 1
        self.assertEqual(count_after, 0)
        self.assertTrue(torch.allclose(prog.exported_program().module()(x), f(x)))

    def test_convert_symb_ops(self) -> None:
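        # EdgeToBackendOpsPass should lower the symbolic subtraction
        # (x.shape[0] - 1) to the executorch_prim sub.Scalar backend op rather
        # than Python's operator.sub.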
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(x, x.shape[0] - 1)

        f = Foo()

        # Mark the 0th dimension of X as dynamic with a max value of 3.
        dim_x = torch.export.Dim("dim_x", max=3)

        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
                dynamic_shapes={"x": {0: dim_x}},
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        new_prog = prog.transform([EdgeToBackendOpsPass()])
        self.assertIsNotNone(new_prog.exported_program().graph_module)
        converted_gm = new_prog.exported_program().graph_module

        FileCheck().check("torch.ops.aten.sym_size.int").check(
            "executorch_exir_dialects_backend__ops_executorch_prim_sub_Scalar"
        ).check_not("operator.sub").run(converted_gm.code)

    def test_alloc_node_spec(self) -> None:
        """
        Make sure every memory.alloc node, including those in sub graph
        modules, has a TensorSpec.
        """
        eager_model = FTMapBasic()
        inputs = eager_model.get_random_inputs()
        prog = to_edge(
            export(
                eager_model,
                inputs,
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        passes = [
            SpecPropPass(),
            HintBasedSymShapeEvalPass(),
        ]
        new_prog = prog.transform(passes)

        new_gm_res = ToOutVarPass()(new_prog.exported_program().graph_module)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module

        new_gm_res = MemoryPlanningPass()(new_gm)
        self.assertIsNotNone(new_gm_res)
        new_gm = new_gm_res.graph_module

        alloc_nodes = []
        for subgm in new_gm.modules():
            if isinstance(subgm, torch.fx.GraphModule):
                for node in subgm.graph.nodes:
                    if node.target == memory.alloc:
                        alloc_nodes.append(node)
        self.assertTrue(len(alloc_nodes) > 0)
        for node in alloc_nodes:
            self.assertTrue(isinstance(node.meta.get("spec", None), TensorSpec))

    def test_debug_pass_file_log(self) -> None:
        eager_model = Mul()
        inputs = eager_model.get_random_inputs()

        # The debug pass works on a graph generated directly with make_fx.
        gm = make_fx(eager_model)(*inputs)

        try:
            fd, path = tempfile.mkstemp()

            print(f"Write DebugPass output to {path}")
            DebugPass(log_filename=path)(gm)
            with open(path) as f:
                file_cont = f.read()
            self.assertTrue("torch.ops.aten.mul" in file_cont)
        finally:
            os.close(fd)
            os.unlink(path)

    def test_dce_recursive(self) -> None:
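        # FTCondDeadCode contains a dead aten.sub computation; DCE should
        # remove it. collect_ops walks submodules recursively, so the check
        # also covers control-flow submodules.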
        eager_model = FTCondDeadCode()
        inputs = eager_model.get_random_inputs()
        gm = export(
            eager_model,
            inputs,
        ).graph_module

        self.assertTrue(torch.ops.aten.sub.Tensor in collect_ops(gm))
        dead_code_elimination_pass(gm)
        gm.print_readable()
        self.assertFalse(torch.ops.aten.sub.Tensor in collect_ops(gm))

    def test_propagate_dynamic_shape(self) -> None:
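        # After propagate_dynamic_shape, every node's TensorSpec should carry
        # fully concrete (int) dimensions.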
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                y = x
                for _ in range(2):
                    y = y + x
                return y

        f = Foo()

        prog = to_edge(
            export(
                f,
                (torch.rand(5),),
            ),
            # missing dispatch key
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        ).transform(propagate_dynamic_shape())
        gm = prog.exported_program().graph_module
        nspec = 0
        for n in gm.graph.nodes:
            for spec in pytree.tree_flatten(n.meta["spec"])[0]:
                self.assertTrue(all(isinstance(x, int) for x in spec.shape))
                nspec += 1

        self.assertTrue(nspec > 0)

    def test_losing_symbolic_info(self) -> None:
        """
        Guard against an issue where, after calling ConvertSymbolicOpsPass(),
        subsequent ExportPasses would lose symbolic shape information.
        """

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(x, x.shape[0] - 1)

        f = Foo()

        dim_x = torch.export.Dim("dim_x", max=3)
        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
                dynamic_shapes={"x": {0: dim_x}},
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )

        new_prog = prog.transform([EdgeToBackendOpsPass()])
        gm = new_prog.exported_program().graph_module
        gm.print_readable()
        *_, ones, out = gm.graph.nodes
        print(f"Before ExportPass: {ones.format_node()}")
        self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt))
        self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0)

        new_prog = new_prog.transform([ExportPass()])
        gm = new_prog.exported_program().graph_module
        gm.print_readable()
        *_, ones, out = gm.graph.nodes
        print(f"After ExportPass: {ones.format_node()}")
        self.assertTrue(isinstance(ones.meta["val"].shape[0], torch.SymInt))
        self.assertTrue(len(ones.meta["val"].shape[0].node.expr.free_symbols) > 0)

    def test_to_edge_with_edge_ops(self) -> None:
        x = torch.randn([2, 3, 4, 5])

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + x

        f = Foo()

        gm = (
            to_edge(
                export(
                    f,
                    (x,),
                )
            )
            .exported_program()
            .graph_module
        )
        for node in gm.graph.nodes:
            if node.op == "call_function":
                self.assertEqual(type(node.target), EdgeOpOverload)

    # TODO(T143084047)
    @unittest.expectedFailure
    def test_backend_fused_op_retraceable(self) -> None:
        """This test makes sure the backend op is still retraceable, with the pattern being registered as kernel."""

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                z = x + y
                return torch.ops.aten.relu.default(z)

        f = Foo()

        gm = export(
            f,
            (
                torch.randn(2, 2),
                torch.randn(2, 2),
            ),
        )
        # should look like:
        # graph():
        #     %ph_0 : [#users=1] = placeholder[target=ph_0]
        #     %ph_1 : [#users=1] = placeholder[target=ph_1]
        #     %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, %ph_1), kwargs = {})
        #     %relu_default : [#users=1] = call_function[target=torch.ops.aten.relu.default](args = (%add_tensor,), kwargs = {})
        #     return [relu_default]
        FileCheck().check("torch.ops.aten.add.Tensor").check(
            "torch.ops.aten.relu.default"
        ).run(gm.graph_module.code)

        class AddReluFusionPass(ExportPass):
            def call(self, graph_module: GraphModule) -> PassResult:
                # The decorator registers this pattern as a
                # CompositeExplicitAutograd kernel, since no kernel was
                # registered for it before.
                @bind_pattern_to_op(lib, "add_relu")
                def pattern(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    z = torch.ops.aten.add.Tensor(x, y)
                    out = torch.ops.aten.relu.default(z)
                    return out

                def replacement(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return ops.backend.DO_NOT_USE_TEST_ONLY.add_relu.default(x, y)

                subgraph_rewriter.replace_pattern(graph_module, pattern, replacement)
                return PassResult(graph_module, True)

        # TODO(larryliu): this pass needs to be in to_executorch()
        class OpReplacePass(ExportPass):
            def call_operator(self, op, args, kwargs, meta):
                if op == torch.ops.DO_NOT_USE_TEST_ONLY.add_relu.default:
                    return super().call_operator(
                        ops.backend.DO_NOT_USE_TEST_ONLY.add_relu.default,
                        args,
                        kwargs,
                        meta,
                    )
                return super().call_operator(op, args, kwargs, meta)

        gm_lowered = to_edge(
            gm,
            compile_config=EdgeCompileConfig(
                _check_ir_validity=False,
            ),
        ).transform([AddReluFusionPass(), OpReplacePass()])

        FileCheck().check(
            "executorch_exir_dialects_backend__ops_DO_NOT_USE_TEST_ONLY_add_relu_default"
        ).run(gm_lowered.exported_program().graph_module.code)
        # lowered module:
        # def forward(self, ph_0, ph_1):
        #     do_not_use_test_only_add_relu_default = executorch_exir_dialects_backend__ops_DO_NOT_USE_TEST_ONLY_add_relu_default(ph_0, ph_1);  ph_0 = ph_1 = None
        #     return [do_not_use_test_only_add_relu_default]

        # Retrace:
        # If this were not a backend op, retracing would error out because no
        # CPU/CompositeExplicitAutograd kernel is registered.
        gm_retraced = to_edge(
            export(
                gm_lowered.exported_program().module(),
                (
                    torch.randn(2, 2),
                    torch.randn(2, 2),
                ),
            )
        )
        # Retraceable: the graph is promoted back to the ATen dialect, so add
        # and relu show up again, which is expected.
        FileCheck().check("torch.ops.aten.add.Tensor").check(
            "torch.ops.aten.relu.default"
        ).run(gm_retraced.exported_program().graph_module.code)

    def test_debug_handle_generator_pass(self) -> None:
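        # to_edge should assign a debug_handle to every node, and the handles
        # should survive later passes such as ScalarToTensorPass.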
        eager_model = MLP(2, output_size=4)
        inputs = eager_model.get_random_inputs()

        graph_module = (
            to_edge(
                export(
                    eager_model,
                    inputs,
                )
            )
            .exported_program()
            .graph_module
        )
        for node in graph_module.graph.nodes:
            self.assertIn("debug_handle", node.meta)
        ScalarToTensorPass()(graph_module)
        for node in graph_module.graph.nodes:
            self.assertIn("debug_handle", node.meta)

    def test_generate_missing_debug_handles(self) -> None:
        eager_model = MLP(2, output_size=4)
        inputs = eager_model.get_random_inputs()

        ep = to_edge(
            export(
                eager_model,
                inputs,
            )
        ).exported_program()

        list(ep.graph.nodes)[0].meta.pop("debug_handle")
        self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is None)
        generate_missing_debug_handles(ep)
        self.assertTrue(list(ep.graph.nodes)[0].meta.get("debug_handle") is not None)

    def test_debug_handle_generator_pass_with_control_flow(self) -> None:
        def true_nested(y: torch.Tensor) -> torch.Tensor:
            y = y + y
            y = torch.mm(y, y)
            return y

        def false_nested(y: torch.Tensor) -> torch.Tensor:
            return torch.mm(y, y)

        def true_fn(x: torch.Tensor, pred2: torch.Tensor) -> torch.Tensor:
            z = control_flow.cond(pred2, true_nested, false_nested, [x])
            return x + z

        def false_fn(x: torch.Tensor, _) -> torch.Tensor:
            return x.cos()

        def map_fn(
            x: torch.Tensor, pred1: torch.Tensor, pred2: torch.Tensor, y: torch.Tensor
        ) -> torch.Tensor:
            x = x.cos()
            y = control_flow.cond(pred1, true_fn, false_fn, [y, pred2])
            x = x + y
            return x.sin()

        class Foo(torch.nn.Module):
            def forward(
                self,
                xs: torch.Tensor,
                pred1: torch.Tensor,
                pred2: torch.Tensor,
                y: torch.Tensor,
            ) -> torch.Tensor:
                y = torch.mm(y, y)
                return control_flow.map(map_fn, xs, pred1, pred2, y)

        f = Foo()

        inputs = (
            torch.ones(2, 2),
            torch.tensor([False]),
            torch.tensor([False]),
            torch.ones(2, 2),
        )

        ep = to_edge(
            export(
                f,
                inputs,
            )
        ).exported_program()
        graph_module = ep.graph_module

        def check_debug_handle_metadata(graph_module: torch.fx.GraphModule) -> None:
            queue = [graph_module]
            while queue:
                current_graph_module = queue.pop(0)
                for node in current_graph_module.graph.nodes:
                    self.assertIn("debug_handle", node.meta)
                control_flow_submodules = [
                    submodule
                    for _, submodule, _ in get_control_flow_submodules(
                        current_graph_module
                    )
                ]
                queue.extend(control_flow_submodules)

        DebugHandleGeneratorPass()(graph_module)
        check_debug_handle_metadata(graph_module)
        generate_missing_debug_handles(ep)

        # Check that debug handles are still preserved after ScalarToTensorPass.
        ScalarToTensorPass()(graph_module)
        check_debug_handle_metadata(graph_module)

    def test_symint_conversion(self) -> None:
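        # SymToTensorPass materializes SymInt operands as scalar_tensor calls;
        # the transformed program must stay numerically equivalent.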
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.add(x, x.shape[0] - 1)

        f = Foo()

        dim_x = torch.export.Dim("dim_x", max=3)
        prog = to_edge(
            export(
                f,
                (torch.ones(3, 2),),
                dynamic_shapes={"x": {0: dim_x}},
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        prog = prog.transform([SymToTensorPass()])

        FileCheck().check("torch.ops.aten.scalar_tensor.default").run(
            prog.exported_program().graph_module.code
        )
        self.assertTrue(
            torch.allclose(
                f(torch.ones(3, 2)), prog.exported_program().module()(torch.ones(3, 2))
            )
        )
        self.assertTrue(
            torch.allclose(
                f(torch.zeros(3, 2)),
                prog.exported_program().module()(torch.zeros(3, 2)),
            )
        )

    def test_remove_assert_pass(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                assert x.shape[0] == 5
                return x * x

        f = Foo()

        gm = to_edge(
            export(
                f,
                (torch.randn(5),),
            ),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        new_gm = gm.transform([RemoveGraphAssertsPass()])
        num_asserts = [
            node
            for node in new_gm.exported_program().graph.nodes
            if node.op == "call_function"
            and node.target == torch.ops.aten._assert_async.msg
        ]
        self.assertEqual(len(num_asserts), 0)

    def test_arange(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.ones(2)

            def forward(self, x):
                return torch.arange(start=0, end=2) + x

        _ = to_edge(
            export(
                M(),
                (torch.randn(2),),
            )
        ).to_executorch()

    def test_replace_slice(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.ones(10)

            def forward(self, x):
                return self.a[:2] + x

        gm = (
            to_edge(
                export(
                    M(),
                    (torch.randn(2),),
                )
            )
            .exported_program()
            .graph_module
        )
        FileCheck().check(
            "executorch_exir_dialects_edge__ops_aten_slice_copy_Tensor"
        ).run(gm.code)

    def test_constant_prop_pass_for_add(self) -> None:
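        # x + 3 exports with a lifted constant tensor followed by a
        # to_dim_order_copy; constant_prop_pass should fold them into a single
        # _prop_tensor_constant placeholder.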
        class Add(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + 3

        add = Add()

        edge = to_edge(
            export(add, (torch.ones(1),)),
            compile_config=EdgeCompileConfig(_skip_dim_order=False),
        )
        edge = edge.transform([ScalarToTensorPass(), RemoveMixedTypeOperators()])
        exported_program = lift_constant_tensor_pass(edge.exported_program())

        # Check there is a lifted tensor followed by a to_copy node
        FileCheck().check("_lifted_tensor_constant0").check(
            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
        ).run(exported_program.graph_module.code)

        new_ep = constant_prop_pass(exported_program)

        # Check (_lifted_tensor_constant + to_copy) node is replaced by prop tensor
        FileCheck().check_not("_lifted_tensor_constant").check(
            "_prop_tensor_constant0"
        ).check_not(
            "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"
        ).run(
            new_ep.graph_module.code
        )

    def test_constant_prop_pass_for_parameter(self) -> None:
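        # Two of the three additions depend only on the parameter a, so
        # constant_prop_pass should fold them away and leave a single add.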
        def count_additions(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes
            )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(1, 2, 3))

            def forward(self, x):
                b = self.a + self.a
                c = torch.cat([self.a, b])
                return (c + c) + x

        aten = export(
            M(),
            (torch.zeros(2, 2, 3),),
        )
        self.assertEqual(count_additions(aten.graph_module), 3)
        new_ep = constant_prop_pass(aten)
        self.assertEqual(count_additions(new_ep.graph_module), 1)

    def test_constant_prop_pass_graph_signature(self) -> None:
        def count_additions(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.add.Tensor) for node in gm.graph.nodes
            )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(1, 2, 3))

            def forward(self, x):
                b = self.a + self.a
                c = torch.cat([self.a, b])
                return (c + c) + x

        aten = export(
            M(),
            (torch.zeros(2, 2, 3),),
        )
        # Input signature will have two entries:
        # (1) parameter `a` and (2) user input `x`.
        self.assertEqual(len(aten.graph_signature.input_specs), 2)
        new_ep = constant_prop_pass(aten)
        # Check that there are exactly two input specs: (1) the propagated
        # constant and (2) the user input.
        self.assertEqual(
            new_ep.graph_signature.input_specs,
            [
                InputSpec(
                    kind=InputKind.CONSTANT_TENSOR,
                    arg=TensorArgument(name="_prop_tensor_constant0"),
                    target="_prop_tensor_constant0",
                    persistent=True,
                ),
                # User input graph signature.
                aten.graph_signature.input_specs[-1],
            ],
        )

    def test_constant_prop_pass_for_parameter_slice(self) -> None:
        def count_slice(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.slice_copy.Tensor)
                for node in gm.graph.nodes
            )

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(3, 2, 2))

            def forward(self, x):
                # Create slice of shape (1, 2, 2)
                slice_tensor = torch.slice_copy(self.a, dim=0, start=0, end=1)
                return torch.cat([x, slice_tensor])

        aten = export(
            M(),
            (torch.zeros(2, 2, 2),),
        )
        self.assertIn("a", aten.state_dict)
        self.assertEqual(count_slice(aten.graph_module), 1)

        new_ep = constant_prop_pass(aten)
        # Check there is a propagated tensor.
        FileCheck().check("_prop_tensor_constant0").run(new_ep.graph_module.code)
        self.assertIn("_prop_tensor_constant0", new_ep.constants)
        self.assertNotIn("a", new_ep.state_dict)
        # No more slice copy.
        self.assertEqual(count_slice(new_ep.graph_module), 0)

    def test_constant_prop_pass_no_propagate(self) -> None:
        def count_placeholder(gm: torch.fx.GraphModule) -> int:
            return sum((node.op == "placeholder") for node in gm.graph.nodes)

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.a = torch.nn.Parameter(torch.ones(3, 2, 4))

            def forward(self, x, y):
                # y is unused.
                return x + self.a

        aten = export(
            M(),
            (torch.zeros(3, 2, 4), torch.zeros(3, 2, 4)),
        )
        self.assertIn("a", aten.state_dict)
        self.assertEqual(count_placeholder(aten.graph_module), 3)

        new_ep = constant_prop_pass(aten)
        # Check there is no propagated tensor.
        FileCheck().check("p_a").check("x").check("y").run(new_ep.graph_module.code)
        self.assertNotIn("_prop_tensor_constant0", new_ep.constants)
        self.assertIn("a", new_ep.state_dict)
        self.assertEqual(count_placeholder(new_ep.graph_module), 3)

    def test_constant_prop_pass_for_control_flow(self) -> None:
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 3)

            def t(self, val):
                return val + 1

            def f(self, val):
                return val - 1

            def true_fn(self, val):
                return self.linear(val) + self.t(val)

            def false_fn(self, val):
                return self.linear(val) - self.f(val)

            def forward(self, pred, x):
                return torch.ops.higher_order.cond(
                    pred, self.true_fn, self.false_fn, [x]
                )

        mod = Module()
        x = torch.randn([3, 3])
        pred = torch.tensor(x[0][0].item() < 0)
        edge = to_edge(
            export(mod, (pred, x)),
            compile_config=exir.EdgeCompileConfig(_check_ir_validity=False),
        )
        error_msg = r"constant_prop_pass for control flow is not supported yet."

        # TODO(chenlai): enable constant prop pass for control flow
        with self.assertRaisesRegex(
            RuntimeError,
            error_msg,
        ):
            _ = constant_prop_pass(edge.exported_program())

    def test_mutable_buffers(self) -> None:
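        # The exported program returns the updated buffer value;
        # insert_write_back_for_buffers_pass should add exactly one copy_ that
        # writes it back to the buffer placeholder.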
        def count_copies(gm: torch.fx.GraphModule) -> int:
            return sum(
                (node.target == torch.ops.aten.copy_.default) for node in gm.graph.nodes
            )

        class MutableStateModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("state", torch.zeros(1))

            def forward(self, x):
                y = x + self.state
                self.state.add_(1)
                return y

        model = to_edge(
            export(
                MutableStateModule(),
                (torch.zeros(1),),
            )
        )
        self.assertEqual(count_copies(model.exported_program().graph_module), 0)
        # Before
        # graph():
        #     %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
        #     %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
        #     %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
        #     %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
        #     %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
        #     %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
        #     return (aten_add_tensor_1, aten_add_tensor)
        gm, _ = insert_write_back_for_buffers_pass(model.exported_program())

        # After
        # graph():
        #     %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
        #     %_lifted_tensor_constant1 : [num_users=1] = placeholder[target=_lifted_tensor_constant1]
        #     %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
        #     %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg1_1, %arg0_1), kwargs = {})
        #     %aten__to_copy_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%_lifted_tensor_constant1,), kwargs = {dtype: torch.float32})
        #     %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%arg0_1, %aten__to_copy_default), kwargs = {})
        #     %copy__default : [num_users=1] = call_function[target=torch.ops.aten.copy_.default](args = (%arg0_1, %aten_add_tensor_1), kwargs = {})
        #     return (copy__default, aten_add_tensor)
        self.assertEqual(count_copies(gm), 1)

    def test_remove_quantized_op_noop_pass(self) -> None:
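        # A full-range slice sandwiched between dq/q nodes is a no-op:
        # RemoveNoopPass should drop it along with one dq/q pair, but must
        # leave a narrowing slice (x[:1]) and its surrounding dq/q nodes intact.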
        class TestAddSliceNoop(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x + x[:]
                return x

        class TestAddSliceNotNoop(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x + x[:1]
                return x

        def count_dq_nodes(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target
                    in (
                        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
                        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                    )
                )
                for node in gm.graph.nodes
            )

        def count_q_nodes(gm: torch.fx.GraphModule) -> int:
            return sum(
                (
                    node.target
                    in (
                        torch.ops.quantized_decomposed.quantize_per_tensor.default,
                        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                    )
                )
                for node in gm.graph.nodes
            )

        def quantize_model(
            m_eager: torch.nn.Module, example_inputs: Tuple[torch.Tensor]
        ) -> Tuple[EdgeProgramManager, int, int]:
            # program capture
            m = torch.export.export_for_training(
                m_eager,
                example_inputs,
            ).module()

            quantizer = XNNPACKQuantizer()
            quantization_config = get_symmetric_quantization_config()
            quantizer.set_global(quantization_config)
            m = prepare_pt2e(m, quantizer)  # pyre-fixme[6]
            m = convert_pt2e(m, fold_quantize=True)
            ep = torch.export.export(m, example_inputs)
            dq_nodes_pre = count_dq_nodes(ep.graph_module)
            q_nodes_pre = count_q_nodes(ep.graph_module)
            edge = to_edge(
                ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
            )
            return edge, dq_nodes_pre, q_nodes_pre

        example_inputs = (torch.randn(9, 8),)
        model = TestAddSliceNoop()
        m_eager = model.eval()
        edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)

        dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
        q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
        # One dq and one q node around the slice copy should have been removed.
        self.assertEqual(dq_nodes_pre - dq_nodes_post, 1)
        self.assertEqual(q_nodes_pre - q_nodes_post, 1)

        # Check that the slice_copy is removed by the RemoveNoopPass.
        for node in edge.exported_program().graph_module.graph.nodes:
            self.assertFalse("slice" in str(node.target))

        model = TestAddSliceNotNoop()
        m_eager = model.eval()
        edge, dq_nodes_pre, q_nodes_pre = quantize_model(m_eager, example_inputs)

        dq_nodes_post = count_dq_nodes(edge.exported_program().graph_module)
        q_nodes_post = count_q_nodes(edge.exported_program().graph_module)
        # Since the slice is not a no-op, no dq or q nodes should have been
        # removed.
        self.assertEqual(dq_nodes_pre, dq_nodes_post)
        self.assertEqual(q_nodes_pre, q_nodes_post)

        # Check that the slice_copy is not removed by the RemoveNoopPass.
        self.assertTrue(
            any(
                "slice" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )

    def test_dq_q_no_op_pass(self) -> None:
        class TestDqQ(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
                    x, 1.0, 0, -128, 127, torch.int8
                )
                q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
                    dq, 1.0, 0, -128, 127, torch.int8
                )
                return q

        model = TestDqQ()
        m_eager = model.eval()
        ep = torch.export.export(m_eager, (torch.randn(9, 8),))
        edge = to_edge(ep)
        # Check that the dq and q nodes are not touched by the RemoveNoopPass.
        self.assertTrue(
            any(
                "dequantize" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )
        self.assertTrue(
            any(
                "quantize" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )

    def test_dq_q_different_qparams(self) -> None:
        class TestDqQDifferentQParam(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
                    x, 1.0, 0, -128, 127, torch.int8
                )
                slice_copy_output = torch.ops.aten.slice_copy.Tensor(dq, 0, 0)
                q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
                    slice_copy_output, 1.0, 0, -127, 127, torch.int8
                )
                return q

        model = TestDqQDifferentQParam()
        m_eager = model.eval()
        ep = torch.export.export(m_eager, (torch.randn(9, 8),))
        edge = to_edge(ep)
        print(edge.exported_program().graph_module.graph)
        # Check that the dq and q nodes are not touched by the RemoveNoopPass.
        self.assertTrue(
            any(
                "dequantize" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )
        self.assertTrue(
            any(
                "quantize" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )
        self.assertFalse(
            any(
                "slice" in str(node.target)
                for node in edge.exported_program().graph_module.graph.nodes
            )
        )

    def test_normalize_view_copy_base_pass(self) -> None:
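        # NormalizeViewCopyBasePass should re-point each view_copy in the chain
        # at the original non-view base, so that afterwards no view_copy has
        # another view_copy as its base.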

        class ViewChain(torch.nn.Module):
            def forward(self, x):
                x = torch.ops.aten.view_copy.default(x, [30, 1])
                x = torch.ops.aten.view_copy.default(x, [5, 6])
                x = torch.ops.aten.view_copy.default(x, [2, 15])
                x = torch.ops.aten.view_copy.default(x, [3, -1])
                return x

        def is_view_copy(node: torch.fx.Node) -> bool:
            return (
                node.op == "call_function"
                and node.target == torch.ops.aten.view_copy.default
            )

        gm = export(ViewChain(), (torch.ones(30),)).graph_module

        # Check before transformation
        n_view_copy_before = 0
        n_view_copy_bases_before = 0
        for node in gm.graph.nodes:
            if is_view_copy(node):
                n_view_copy_before += 1
                base = node.args[0]
                if is_view_copy(base):
                    n_view_copy_bases_before += 1

        self.assertEqual(n_view_copy_before, 4)
        self.assertEqual(n_view_copy_bases_before, 3)

        # Do transformation
        p = NormalizeViewCopyBasePass()
        gm_res = p(gm)
        assert gm_res is not None
        gm = gm_res.graph_module

        # Check after transformation
        n_view_copy_after = 0
        n_view_copy_bases_after = 0
        for node in gm.graph.nodes:
            if is_view_copy(node):
                n_view_copy_after += 1
                base = node.args[0]
                if is_view_copy(base):
                    n_view_copy_bases_after += 1

        self.assertEqual(n_view_copy_after, 4)
        self.assertEqual(n_view_copy_bases_after, 0)

    def test_replace_view_copy_with_view_pass(self) -> None:  # noqa: C901

        # Helper functions
        def is_view_copy(node: torch.fx.Node) -> bool:
            return (
                node.op == "call_function"
                and node.target == torch.ops.aten.view_copy.default
            )

        def is_memory_view(node: torch.fx.Node) -> bool:
            return node.op == "call_function" and node.target == memory.view

        # Test example set up
        class TestViewCopies(torch.nn.Module):
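            """Calls view_copy on a graph input and on a parameter; the pass
            under test should replace both with memory.view ops."""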
            def __init__(self):
                super().__init__()
                self.parameter = torch.nn.Parameter(torch.ones(1))

            def forward(self, x):
                o1 = torch.ops.aten.view_copy.default(x, [1])
                o2 = torch.ops.aten.view_copy.default(self.parameter, [1])
                # view_copy ops at the end of a graph are not replaced, so add a
                # computation after them before returning.
                return torch.ops.aten.add.Tensor(o1, o2)

        ep = torch.export.export(
            TestViewCopies(),
            args=(torch.ones(1),),
        )
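        # Give every placeholder a static TensorSpec; the replacement pass
        # presumably consults this spec metadata when deciding whether a
        # view_copy can be replaced safely.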
        for node in ep.graph.nodes:
            if node.op == "placeholder":
                node.meta["spec"] = TensorSpec.from_tensor(torch.empty(1))
                node.meta["spec"].shape_dynamism = TensorShapeDynamism.STATIC

        # Run tests
        gm = ep.graph_module

        # Check before transformation
        FileCheck().check_count(
            "torch.ops.aten.view_copy.default", 2, exactly=True
        ).run(gm.code)
        FileCheck().check_count("executorch_exir_memory_view", 0, exactly=True).run(
            gm.code
        )

        # Do transformation
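        # ReplaceViewCopyWithViewPass is expected to swap aten.view_copy ops for
        # non-copying executorch memory.view ops, which print as
        # executorch_exir_memory_view in the generated code.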
        p = ReplaceViewCopyWithViewPass()
        gm_res = p(gm)
        assert gm_res is not None
        gm = gm_res.graph_module

        # Check after transformation
        FileCheck().check_count(
            "torch.ops.aten.view_copy.default", 0, exactly=True
        ).run(gm.code)
        FileCheck().check_count("executorch_exir_memory_view", 2, exactly=True).run(
            gm.code
        )

    def test_constant_prop_pass_for_no_grad(self) -> None:
        class LSTM(torch.nn.Module):
            def __init__(self, input_size, hidden_size, num_layers):
                super().__init__()
                self.hidden_size = hidden_size
                self.num_layers = num_layers
                self.lstm = torch.nn.LSTM(
                    input_size, hidden_size, num_layers, batch_first=True
                )

            def forward(self, text_tokens):
                # input: (batch, seq_len, input_size) since batch_first=True
                lstm_out, (new_hidden_state, new_cell_state) = self.lstm(
                    input=text_tokens, hx=None
                )
                return lstm_out

        lstm = LSTM(input_size=200, hidden_size=203, num_layers=2)
        example_input = (torch.rand(2, 10, 200),)

        aten = torch.export.export(lstm, example_input, strict=False)
        _EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
            _check_ir_validity=True,
            _skip_dim_order=True,  # TODO(T189114319): Reuse dim order op after solving the iOS OSS issue
        )

        edge_manager: EdgeProgramManager = to_edge(
            aten,
            compile_config=_EDGE_COMPILE_CONFIG,
        )
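        # Run constant propagation on the exported (strict=False) LSTM program;
        # deep-copying the module_call_graph afterwards verifies that the pass
        # leaves the program in a copyable state.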
        new_ep = constant_prop_pass(edge_manager._edge_programs["forward"])
        _ = copy.deepcopy(new_ep.module_call_graph)

    def test_dim_order_revert_pass(self) -> None:
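        # Spellings of the `to` copy op as they appear in generated graph code:
        # plain aten, edge-dialect aten, and the edge dim-order variant.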
        aten_op_str = "torch.ops.aten._to_copy.default"
        edge_aten_op_str = "executorch_exir_dialects_edge__ops_aten__to_copy_default"
        edge_dim_order_op_str = "executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default"

        class Module(torch.nn.Module):
            """
            A simple module that has a single to op that converts to channels last and then back to contiguous.
            Assuming contiguous input.
            """

            def __init__(self):
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x.to(memory_format=torch.channels_last).to(
                    memory_format=torch.contiguous_format
                ) + x.to(memory_format=torch.channels_last).to(
                    memory_format=torch.contiguous_format
                )

            @staticmethod
            def to_copy_count():
                return 4

        def _do_checks(
            test_str: str, allowed: str, allowed_count: int, not_allowed_list: List[str]
        ) -> None:
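            """Assert that test_str contains exactly allowed_count occurrences
            of allowed and none of the strings in not_allowed_list."""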
            for not_allowed in not_allowed_list:
                FileCheck().check_count(allowed, allowed_count, exactly=True).check_not(
                    not_allowed
                ).run(test_str)

        m = Module()
        n = m.to_copy_count()
        input = torch.randn([2, 3, 4, 5]).to(memory_format=torch.contiguous_format)

        # 1. vanilla export, no edge ops
        ep = export(
            m,
            (input,),
        ).run_decompositions({})
        _do_checks(
            ep.graph_module.code,
            aten_op_str,
            n,
            [edge_aten_op_str, edge_dim_order_op_str],
        )

        # 2a. to_edge with dim order ops skipped: we should see edge aten ops but not dim order ops
        edge_prog = to_edge(
            ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=True)
        )._edge_programs["forward"]
        _do_checks(
            edge_prog.graph_module.code,
            edge_aten_op_str,
            n,
            [aten_op_str, edge_dim_order_op_str],
        )

        # 3a. expect no change after the pass: we should still see edge aten ops but not dim order ops
        new_res = DimOrderOpsRevertPass()(edge_prog.graph_module)
        self.assertIsNotNone(new_res)
        _do_checks(
            new_res.graph_module.code,
            edge_aten_op_str,
            n,
            [aten_op_str, edge_dim_order_op_str],
        )

        # 2b. to_edge with dim order ops enabled: we should see edge dim order ops but not edge aten ops
        edge_prog_dim_order = to_edge(
            ep, compile_config=exir.EdgeCompileConfig(_skip_dim_order=False)
        )._edge_programs["forward"]
        _do_checks(
            edge_prog_dim_order.graph_module.code,
            edge_dim_order_op_str,
            n,
            [aten_op_str, edge_aten_op_str],
        )

        # 3b. expect edge aten ops after the pass: we should not see the edge dim order ops
        new_res_dim_order = DimOrderOpsRevertPass()(edge_prog_dim_order.graph_module)
        self.assertIsNotNone(new_res_dim_order)
        _do_checks(
            new_res_dim_order.graph_module.code,
            edge_aten_op_str,
            n,
            [aten_op_str, edge_dim_order_op_str],
        )

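        # Running both reverted graphs on the same input should produce matching results.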
        output_no_dim_order = new_res.graph_module(input)
        output_no_dim_order_revert = new_res_dim_order.graph_module(input)
        self.assertTrue(
            torch.allclose(output_no_dim_order[0], output_no_dim_order_revert[0])
        )
