# Owner(s): ["oncall: jit"]

import io
import unittest
from itertools import product
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit._recursive import wrap_cpp_module
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_CUDNN
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
from torch.testing._internal.common_quantized import override_quantized_engine
from torch.testing._internal.common_utils import (
    set_default_dtype,
    skipCUDAMemoryLeakCheckIf,
    skipIfTorchDynamo,
    TEST_WITH_ROCM,
)
from torch.testing._internal.jit_utils import JitTestCase
from torch.utils import mkldnn as mkldnn_utils


try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

TEST_ROCM = torch.cuda.is_available() and torch.version.hip is not None


def removeExceptions(graph):
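    """Remove all prim::RaiseException nodes from the given JIT graph in place."""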
    for n in graph.findAllNodes("prim::RaiseException"):
        n.destroy()


class TestFreezing(JitTestCase):
    def test_freeze_module(self):
        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1  # folded
                self.b = 1.2  # folded
                self.c = "hello"  # folded
                self.c2 = "hi\xA1"  # not folded
                self.d = [1, 1]  # folded
                self.e = [1.0, 1.1]  # folded
                self.f = ["hello", "world"]  # folded
                self.f2 = [(1, "Over \u0e55\u0e57 57")]
                self.g = (
                    [1, 2],
                    3.2,
                    "4.4",
                    torch.tensor([5.5], requires_grad=True),
                )  # folded
                self.h = {"layer": [torch.tensor([7.7], requires_grad=True)]}
                self.h2 = {"layer\xB1": [torch.tensor([8.8], requires_grad=True)]}
                self.t = torch.tensor([1.2, 2.4], requires_grad=True)  # folded
                self.ts = [
                    torch.tensor([1.0, 2.0], requires_grad=True),
                    torch.tensor([3.0, 4.0], requires_grad=True),
                ]  # folded
                self.tt = [[torch.tensor([3.3, 2.3], requires_grad=True), None]]

            def forward(self, x):
                return (
                    str(self.a)
                    + str(self.b)
                    + self.c
                    + self.c2
                    + str(self.d)
                    + str(self.e)
                    + str(self.f)
                    + str(self.f2)
                    + str(self.g)
                    + str(self.h)
                    + str(self.h2)
                    + str(self.t)
                    + str(self.ts)
                    + str(self.tt)
                )

        m = torch.jit.script(M())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        m._c = torch._C._freeze_module(m._c)
        buffer = io.BytesIO()
        torch.jit.save(m._c, buffer)
        buffer.seek(0)
        m2 = torch.jit.load(buffer)
        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #   }
        #   ...
        # }
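        # (One way to verify such a layout, used later in this file, is
        # m2._c.dump_to_str(True, True, False), which prints the module's
        # attributes and submodules.)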
        self.assertFalse(m2._c.hasattr("a"))
        self.assertFalse(m2._c.hasattr("b"))
        self.assertFalse(m2._c.hasattr("c"))
        self.assertFalse(m2._c.hasattr("c2"))
        self.assertFalse(m2._c.hasattr("d"))
        self.assertFalse(m2._c.hasattr("e"))
        self.assertFalse(m2._c.hasattr("f"))
        self.assertFalse(m2._c.hasattr("f2"))
        self.assertFalse(m2._c.hasattr("g"))
        self.assertFalse(m2._c.hasattr("h"))
        self.assertFalse(m2._c.hasattr("h2"))
        self.assertFalse(m2._c.hasattr("t"))
        self.assertFalse(m2._c.hasattr("ts"))
        self.assertFalse(m2._c.hasattr("tt"))
        output_f = m2.forward(input)
        self.assertEqual(output_s, output_f)
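
    # Note: many tests in this class call the internal torch._C._freeze_module
    # binding directly. A minimal sketch of the equivalent public API (assuming
    # a scripted module `m` already in eval mode):
    #
    #   mf = torch.jit.freeze(m)                         # fold attributes
    #   mf = torch.jit.freeze(m, preserved_attrs=["a"])  # fold all but 'a'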

    def test_freeze_module_with_submodule(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 11
                self.b = 2

            def forward(self, x):
                return self.a + self.b

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 12
                self.b = 2

            def forward(self, x):
                self.b = 30
                return self.a + self.b

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()
                self.sub2 = SubModule2()
                self.a = 3
                self.b = 4

            def forward(self, x):
                self.b = 20
                return self.sub1(x) + self.a + self.b + self.sub2(x)

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch.jit.freeze(m)

        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #     sub2 = ...
        #     b = ...
        #   }
        #   ...
        #   submodule {
        #     module m {
        #       attributes {
        #         sub2 = ...
        #         b = ...
        #       }
        #       ...
        #     }
        #   }
        # }
        mf = mf._c
        self.assertFalse(mf.hasattr("sub1"))
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("b"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("b"))  # verify b is preserved in sub2
        self.assertFalse(mf.sub2.hasattr("a"))  # verify a is removed in sub2
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_fork(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            def forward(self, x):
                return self.a * self.b + x

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()

            def forward(self, x):
                fut = torch.jit._fork(self.sub.forward, x)
                y_hat = self.sub(x)
                y = torch.jit._wait(fut)
                return y_hat + y

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(20, 20)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)

        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #   }
        #   ...
        #   submodule {
        #   }
        # }
        self.assertFalse(mf.hasattr("a"))
        self.assertFalse(mf.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_nested_fork(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            def forward(self, x):
                return self.a * self.b + x

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()
                self.c = torch.ones(20, 20)

            def forward(self, x):
                fut = torch.jit._fork(self.sub.forward, x)
                y_hat = self.sub(x)
                y = torch.jit._wait(fut)
                return y_hat + y + self.c

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule2()
                self.d = 1

            def forward(self, x):
                fut = torch.jit._fork(self.sub.forward, x)
                y_hat = self.sub(x)
                y = torch.jit._wait(fut)
                self.d = 2
                return y_hat * y + self.d

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(20, 20)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)
        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #   }
        #   ...
        #   submodule {
        #   }
        # }
        self.assertFalse(mf.hasattr("a"))
        self.assertFalse(mf.hasattr("b"))
        self.assertFalse(mf.hasattr("c"))
        self.assertTrue(mf.hasattr("d"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_fork2(self):
        @torch.jit.script
        def foo(x):
            return x * 2

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            def forward(self, x):
                fut = torch.jit._fork(foo, self.a)
                y_hat = foo(self.b)
                y = torch.jit._wait(fut)
                return y_hat + y

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)

        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #     a = ...
        #   }
        #   ...
        #   submodule {
        #   }
        # }
        # TODO: Although there is no mutation, the alias analysis conservatively
        # assumes there is one because attribute 'a' is passed to the fork
        # subgraph; 'a' is therefore preserved, while 'b' is folded.
        self.assertTrue(mf.hasattr("a"))
        self.assertFalse(mf.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_fork_calling_module_method(self):
        @torch.jit.script
        def foo(x, y):
            return x * y

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.ones(20, 20)
                self.b = torch.ones(20, 20)

            @torch.jit.export
            def foo(self, x):
                return x * self.a

            @torch.jit.export
            def bar(self, x):
                return x * self.b

            def forward(self, x):
                fut = torch.jit._fork(self.foo, self.b)
                y_hat = self.bar(self.a)
                y = torch.jit._wait(fut)
                return y_hat + y

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)
        # Check if frozen module looks as below:
        # module m {
        #   attributes {
        #     b = ...
        #   }
        #   ...
        # }
        # TODO: Although there is no mutation, the alias analysis conservatively
        # assumes there is one because attribute 'b' is passed to the fork
        # subgraph; 'b' is therefore preserved, while 'a' is folded.
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)
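
    # The two fork tests above exercise the same conservative rule: an
    # attribute that escapes into a forked subgraph is treated as potentially
    # mutated and is preserved, while attributes that are only read through
    # ordinary calls are folded. Schematically (reusing the module above):
    #
    #   fut = torch.jit._fork(self.foo, self.b)  # 'b' escapes -> preserved
    #   y_hat = self.bar(self.a)                 # 'a' only read -> folded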

    def test_freeze_module_with_sharedclasstype(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] += 10
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b[0] += 20
                return self.a

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()
                self.b = torch.tensor([3.3])

            def forward(self, x):
                y = self.sub.modify_b(x)
                return y + self.b

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()  # sub1 and sub2.sub share the same class type.
                self.sub2 = SubModule2()
                self.a = torch.tensor([4.4])

            def forward(self, x):
                z = self.sub1.modify_a(x)
                return self.sub2(x) + z + self.a

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        mf = torch._C._freeze_module(m._c)

        # Check if frozen module looks as below:
        # module mf {
        #   attributes {
        #     sub1 = ...
        #     sub2 = ...
        #   }
        #   ...
        #   submodules {
        #     module sub1 {
        #       attributes {
        #         a = ...
        #         b = ...
        #       }
        #       ...
        #     }
        #     module sub2 {
        #       attributes {
        #         sub = ...
        #       }
        #       ...
        #       submodule {
        #         module sub {
        #           attributes {
        #             a = ...
        #             b = ...
        #           }
        #           ...
        #         }
        #       }
        #     }
        #   }
        # }

        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertTrue(mf.sub1.hasattr("b"))
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("sub"))
        self.assertFalse(mf.sub2.hasattr("b"))
        self.assertTrue(mf.sub2.sub.hasattr("a"))
        self.assertTrue(mf.sub2.sub.hasattr("b"))
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_nestedaliasing(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] = 10
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b[0] = 20
                return self.a

        Sub = SubModule()

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub  # aliasing

            def forward(self, x):
                return self.sub.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = Sub  # aliasing
                self.sub2 = SubModule2()

            def forward(self, x):
                z = self.sub1.modify_a(x)
                return self.sub2(x) + z

        m = torch.jit.script(TestModule())
        m.eval()
        mf = torch._C._freeze_module(m._c)
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertFalse(mf.sub1.hasattr("b"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("sub"))
        self.assertTrue(
            mf.sub2.sub.hasattr("a")
        )  # Freezing detects that self.sub2.sub.a and self.sub1.a are aliases.
        self.assertFalse(mf.sub2.sub.hasattr("b"))
        input = torch.randn(2, 2)
        output_s = m.forward(input)
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    # FIXME: The JIT does not honor aliasing here: the 'Sub' module is copied,
    # so the eager and scripted modules produce different outputs.
    def test_freeze_module_with_nestedaliasingscalar(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1.1
                self.b = 2.2

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a = 10.0
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b = 20.0
                return self.a

        Sub = SubModule()

        class SubModule2(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = Sub  # aliasing

            def forward(self, x):
                return self.sub.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = Sub  # aliasing
                self.sub2 = SubModule2()

            def forward(self, x):
                z = self.sub1.modify_a(x)
                return self.sub2(x) + z

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
        mf = torch._C._freeze_module(ms._c)
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertFalse(mf.sub1.hasattr("b"))
        # sub2 is fully folded because self.sub1 and self.sub2.sub are not aliases (scripting bug).
        self.assertFalse(mf.hasattr("sub2"))
        input = torch.randn(2, 2)
        output = m.forward(input)
        output_s = ms.forward(input)
        output_f = mf.forward(input)
        # Ideally output == output_s; they differ due to the scripting copy noted above.
        self.assertNotEqual(output, output_s)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_preserve_sub_module(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = 2.2

            def forward(self, x):
                return self.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()
                self.sub2 = SubModule()

            def forward(self, x):
                return self.sub2(x) + self.sub1(x)

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
        mf = torch._C._freeze_module(ms._c, ["sub1"])

        # Test that 'sub1' is preserved entirely and 'sub2' is completely folded
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertTrue(mf.sub1.hasattr("b"))
        self.assertFalse(mf.hasattr("sub2"))
        input = torch.randn(2, 2)
        output_s = ms.forward(input)
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_preserve_sub_module_and_mutation(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = 2.2

            def forward(self, x):
                self.a[0] = 3.3
                return self.a

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()
                self.sub2 = SubModule()

            def forward(self, x):
                return self.sub2(x) + self.sub1(x)

        m = TestModule()
        ms = torch.jit.script(m)
        ms.eval()
        mf = torch._C._freeze_module(ms._c, ["sub1"])

        # Test that both sub1 and sub2 are preserved, and that 'b' is preserved
        # even though it is unused, to fulfill the user request to preserve 'sub1'.
        self.assertTrue(mf.hasattr("sub1"))
        self.assertTrue(mf.sub1.hasattr("a"))
        self.assertTrue(mf.sub1.hasattr("b"))
        self.assertTrue(mf.hasattr("sub2"))
        self.assertTrue(mf.sub2.hasattr("a"))
        self.assertTrue(mf.sub2.hasattr("b"))
        input = torch.randn(2, 2)
        output_s = ms.forward(input)
        output_f = mf.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_module_with_helperfunction(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 11
                self.b = 2

            def forward(self, x):
                return self.a + self.b

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()
                self.a = 3
                self.b = 4

            def forward(self, x):
                self.b = 20
                return self._forward(x) + self.a + self.b

            def _forward(self, x):
                return self.sub(x)

        m = torch.jit.script(TestModule())
        m.eval()
        input = torch.randn(2, 2)
        mf = torch._C._freeze_module(m._c)
        self.assertFalse(mf.hasattr("sub"))
        self.assertFalse(mf.hasattr("a"))
        self.assertTrue(mf.hasattr("b"))
        with self.assertRaisesRegex(
            AttributeError, "TestModule (.*) does not have a field with name '_forward'"
        ):
            mf._forward(input)

    def test_freeze_module_with_inplace_mutable(self):
        class FreezeMe(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.a = [11, 22]

            @torch.jit.script_method
            def forward(self, x):
                for i in range(3):
                    self.a.append(i)
                return self.a

        m = FreezeMe()
        m.eval()
        m_f = torch._C._freeze_module(m._c)
        self.assertTrue(m_f.hasattr("a"))
        m.forward(torch.tensor([3]))
        out = m_f.forward(torch.tensor([5]))
        expected = [11, 22, 0, 1, 2, 0, 1, 2]
        self.assertEqual(out, expected)

    # Mutable attributes
    def test_freeze_module_with_mutable_list(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2]

            def forward(self, x):
                return self.a

        m = FreezeMe()
        m.eval()
        m.a.append(3)
        m_s = torch.jit.script(m)
        v = m_s.a
        v.append(4)
        m_s.a = v
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        # Mutating m_s.a after freezing does not affect m_f (m_f has its own copy).
        v = m_s.a
        v.append(5)
        m_s.a = v
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(torch.tensor([5]))
        expected = [1, 2, 3, 4]
        self.assertEqual(out, expected)

    def test_freeze_module_with_mutable_dict(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = {"layer": "4"}

            def forward(self, x):
                return self.a

            @torch.jit.export
            def modify_a(self, x):
                self.a["layer"] = self.a["layer"] + "1"
                return self.a

        m = FreezeMe()
        m.eval()
        m.a["layer2"] = "3"
        m_s = torch.jit.script(m)
        t = torch.tensor(5)
        m_s.modify_a(t)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        m.a["layer2"] += "2"
        m_s.modify_a(t)
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(t)
        expected = {"layer": "411", "layer2": "3"}
        self.assertEqual(out, expected)

    def test_freeze_module_with_mutable_tensor(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.0, 2.0, 3.0])

            def forward(self, x):
                return self.a

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.a[1] += 3.0
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        # Post-freezing tensor attribute mutations affect m_f.
        # FIXME: deep copy all folded attributes so that m_f has full ownership.
        m_s.a[0] += 5.0
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(torch.tensor([5]))
        expected = [6.0, 5.0, 3.0]
        self.assertEqual(out, expected)

    def test_freeze_module_with_tuple(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = (torch.tensor([1, 2, 3, 4, 5, 6]), "hi")

            def forward(self, x):
                if x[0] == 2.0:
                    self.a[0][0] = 10
                return self.a[0].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([2.0])
        expected = m_s.forward(inp)
        m_s.a[0][0] = 1
        m_f = torch._C._freeze_module(m_s._c)
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_tensor(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])

            def forward(self, x):
                x = self.a.view(2, 3)
                x[0][0] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        m_f.a[0] -= 10
        out = m_f.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_list(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [torch.tensor([1, 2, 3, 4, 5, 6])]

            def forward(self, x):
                self.a[0][1] += 10
                return self.a[0].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_s.a[0][1] -= 10
        m_f = torch._C._freeze_module(m_s._c)
        self.assertFalse(m_f.hasattr("a"))
        out = m_f.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_tensor_attr(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = self.a.view(2, 3)

            def forward(self, x):
                self.b[1] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = torch.tensor(51)  # 1+2+3+14+15+16
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_tensor_attr2(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = {"layer": ([self.a.view(2, 3), torch.tensor([10])], 20)}
                self.c = ([self.a.view(2, 3), torch.tensor([10])], 20)
                self.d = (self.a.view(2, 3), 20)

            def forward(self, x):
                self.d[0][0] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        with self.assertRaisesRegex(
            RuntimeError, "module contains attributes values that overlaps"
        ):
            m_f = torch._C._freeze_module(m_s._c)

    def test_freeze_module_with_aliased_tensor_attr3(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = [self.a, torch.tensor([10])]

            def forward(self, x):
                self.a[1] += 10
                return self.b[0].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        self.assertTrue(m_f.hasattr("b"))
        out = m_f.forward(inp)
        expected += 10  # account for the second self.a[1] += 10 in forward.
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_tensor_attr4(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1, 2, 3, 4, 5, 6])
                self.b = [self.a, torch.tensor([10])]

            def forward(self, x):
                self.b[0][0] += 10
                return self.a.sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        m_s.a[0] -= 10
        with self.assertRaisesRegex(
            RuntimeError, "module contains attributes values that overlaps"
        ):
            m_f = torch._C._freeze_module(m_s._c)

    def test_freeze_module_with_overlapping_attrs(self):
        a = torch.tensor([1, 2, 3, 4, 5, 6])

        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = [a.view(3, 2), torch.tensor([10])]
                self.c = (20, a.view(2, 3))

            def forward(self, x):
                self.b[0][0] += 10
                return self.c[1].sum()

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        inp = torch.tensor([5])
        expected = m_s.forward(inp)
        a[0] -= 10
        with self.assertRaisesRegex(
            RuntimeError, "module contains attributes values that overlaps"
        ):
            m_f = torch._C._freeze_module(m_s._c)

    def test_freeze_module_with_aliased_attr(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2, 3, 4, 5, 6]
                self.b = self.a
                self.c = (self.a, 10)

            def forward(self, x):
                self.b[1] += 10
                return str(self.a) + str(self.c)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        # FIXME: It should be assertTrue. Currently scripting makes a copy
        # when setting self.b (see #33034).
        self.assertFalse(m_f.hasattr("a"))
        self.assertFalse(m_f.hasattr("c"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = m_s.forward(inp)
        self.assertEqual(out, expected)

    # Check that attribute 'a' is preserved. Alias analysis detects that 'a'
    # has writers. In this example, 'a' is not actually mutated; however, we do
    # not track which sub-values of a composite IValue are mutated.
    def test_freeze_module_with_aliased_attr2(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2, 3, 4, 5, 6]
                self.b = ([11], [10])

            def forward(self, x):
                v = self.a
                self.b = (v, [12])
                v2 = self.b[1]
                v2.append(7)
                return str(v) + str(v2)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = m.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_with_aliased_attr3(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = [1, 2, 3, 4, 5, 6]
                self.b = ([11], [10])

            def forward(self, x):
                v = self.a
                v2 = (v, [12])
                v3 = v2[0]
                v3.append(7)
                return str(self.a)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("a"))
        inp = torch.tensor([5])
        out = m_f.forward(inp)
        expected = m.forward(inp)
        self.assertEqual(out, expected)

    def test_freeze_module_return_self(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.0, 2.0, 3.0])

            def forward(self, x):
                return self

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        with self.assertRaisesRegex(
            RuntimeError, "attempted to freeze a module that return itself"
        ):
            m_f = torch._C._freeze_module(m_s._c)

    def test_freeze_module_inlining(self):
        @torch.jit.script  # noqa: B903
        class Obj:  # noqa: B903
            def __init__(self, x: int, y: int):
                self.x = x
                self.y = y

        class Mod(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.obj = Obj(2, 3)

            def forward(self, i: int):
                print(self.obj)
                return i

        mod = torch.jit.freeze(torch.jit.script(Mod().eval()))
        obj = mod.graph.findNode("prim::Constant")
        self.assertTrue(torch._C._jit_object_is_non_holding(obj))

        buffer = io.BytesIO()
        torch.jit.save(mod, buffer)
        buffer.seek(0)

        loaded = torch.jit.load(buffer)
        obj = loaded.graph.findNode("prim::Constant")
        self.assertTrue(torch._C._jit_object_is_non_holding(obj))

    def test_freeze_module_return_sub_module(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1)

            def forward(self, x):
                return self.conv1

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c)
        self.assertTrue(m_f.hasattr("conv1"))

    def test_freeze_module_no_forward(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(10, 1)

            @torch.jit.export
            def foo(self, x):
                return self.lin(x)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch._C._freeze_module(m_s._c, preservedAttrs=["foo"])
        input = torch.ones(10)
        self.assertEqual(m_s.foo(input), m_f.foo(input))

    def test_freeze_no_forward(self):
        class FreezeMe(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.lin = nn.Linear(10, 1)

            @torch.jit.export
            def foo(self, x):
                return self.lin(x)

        m = FreezeMe()
        m_s = torch.jit.script(m)
        m_s.eval()
        m_f = torch.jit.freeze(m_s, preserved_attrs=["foo"])
        input = torch.ones(10)
        self.assertEqual(m_s.foo(input), m_f.foo(input))

    def test_freeze_module_in_training_mode(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 32, 3, 1)
                self.conv2 = nn.Conv2d(32, 64, 3, 1)
                self.dropout1 = nn.Dropout2d(0.25)
                self.dropout2 = nn.Dropout2d(0.5)
                self.fc1 = nn.Linear(9216, 128)
                self.fc2 = nn.Linear(128, 10)

            def forward(self, x):
                x = self.conv1(x)
                x = nn.functional.relu(x)
                x = self.conv2(x)
                x = nn.functional.max_pool2d(x, 2)
                x = self.dropout1(x)
                x = torch.flatten(x, 1)
                x = self.fc1(x)
                x = nn.functional.relu(x)
                x = self.dropout2(x)
                x = self.fc2(x)
                output = nn.functional.log_softmax(x, dim=1)
                return output

        model = torch.jit.script(Net())
        model.train()
        mTrain_freezed = torch._C._freeze_module(model._c)
        # verify mTrain_freezed looks exactly as below:
        # module {
        #   attributes {
        #     conv1 = ...
        #     conv2 = ...
        #     dropout1 = ...
        #     dropout2 = ...
        #     fc1 = ...
        #     fc2 = ...
        #   }
        #   ...
        #   submodules {
        #     module conv1 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #     module conv2 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #     module dropout1 {
        #       attributes {
        #          training = ...
        #       }
        #       ...
        #     }
        #     module dropout2 {
        #       attributes {
        #          training = ...
        #       }
        #       ...
        #     }
        #     module fc1 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #     module fc2 {
        #       attributes {
        #          weight = ...
        #          bias = ...
        #       }
        #       ...
        #     }
        #   }
        # }
        self.assertFalse(mTrain_freezed.hasattr("training"))
        self.assertTrue(mTrain_freezed.hasattr("conv1"))
        self.assertFalse(mTrain_freezed.conv1.hasattr("training"))
        self.assertTrue(mTrain_freezed.conv1.hasattr("weight"))
        self.assertTrue(mTrain_freezed.conv1.hasattr("bias"))
        self.assertTrue(mTrain_freezed.hasattr("conv2"))
        self.assertFalse(mTrain_freezed.conv2.hasattr("training"))
        self.assertTrue(mTrain_freezed.conv2.hasattr("weight"))
        self.assertTrue(mTrain_freezed.conv2.hasattr("bias"))
        self.assertTrue(mTrain_freezed.hasattr("dropout1"))
        self.assertTrue(mTrain_freezed.dropout1.hasattr("training"))
        self.assertTrue(mTrain_freezed.hasattr("dropout2"))
        self.assertTrue(mTrain_freezed.dropout2.hasattr("training"))
        self.assertTrue(mTrain_freezed.hasattr("fc1"))
        self.assertTrue(mTrain_freezed.fc1.hasattr("weight"))
        self.assertTrue(mTrain_freezed.fc1.hasattr("bias"))
        self.assertTrue(mTrain_freezed.hasattr("fc2"))
        self.assertTrue(mTrain_freezed.fc2.hasattr("weight"))
        self.assertTrue(mTrain_freezed.fc2.hasattr("bias"))
        model.eval()
        mEval_freezed = torch._C._freeze_module(model._c)
        self.assertFalse(mEval_freezed.hasattr("conv1"))
        self.assertFalse(mEval_freezed.hasattr("conv2"))
        self.assertFalse(mEval_freezed.hasattr("dropout1"))
        self.assertFalse(mEval_freezed.hasattr("training"))
        self.assertFalse(mEval_freezed.hasattr("fc1"))
        self.assertFalse(mEval_freezed.hasattr("dropout2"))
        self.assertFalse(mEval_freezed.hasattr("fc2"))
        with self.assertRaisesRegex(
            AttributeError, "does not have a field with name 'state_dict'"
        ):
            print(mEval_freezed.state_dict())
        buffer = io.BytesIO()
        torch.jit.save(mEval_freezed, buffer)
        buffer.seek(0)
        m = torch.jit.load(buffer)
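        # The loaded frozen forward graph should contain no attribute reads.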
        FileCheck().check_not("GetAttr[name=").run(m._c._get_method("forward").graph)
        m2 = torch._C._freeze_module(model._c, preserveParameters=True)
        self.assertTrue(m2.hasattr("conv1"))
        self.assertTrue(m2.hasattr("conv2"))
        self.assertFalse(m2.hasattr("dropout1"))
        self.assertFalse(m2.hasattr("training"))
        self.assertTrue(m2.hasattr("fc1"))
        self.assertFalse(m2.hasattr("dropout2"))
        self.assertTrue(m2.hasattr("fc2"))

    def test_freeze_module_detach_gradient(self):
        mod = nn.Conv2d(8, 3, 4, 2, 1)
        self.assertTrue(mod.weight.requires_grad)
        smod = torch.jit.script(mod)
        smod.eval()
        fmod = torch._C._freeze_module(smod._c)
        self.assertTrue(mod.weight.requires_grad)
        self.assertTrue(smod.weight.requires_grad)
        self.assertFalse(fmod.hasattr("weight"))
        inp = torch.ones(1, 8, 32, 32)
        out1 = fmod.forward(inp)
        # FIXME: frozen module mutated from outside (original module).
        with torch.no_grad():
            smod.weight[0, 0, 0, 0] += 100.0
        out2 = fmod.forward(inp)
        out3 = smod(inp)
        self.assertNotEqual(out1, out2)
        self.assertEqual(out2, out3)

    def test_freeze_module_with_user_preserved_attr(self):
        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

        m = torch.jit.script(Module())
        m.eval()
        fm = torch._C._freeze_module(m._c, ["a"])
        # Attribute "a" is preserved
        self.assertTrue(fm.hasattr("a"))
        self.assertFalse(fm.hasattr("b"))

    def test_freeze_module_with_user_preserved_method(self):
        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] += 10
                return self.b

            @torch.jit.export
            def modify_b(self, x):
                self.b[0] += 20
                return self.a

        m = torch.jit.script(Module())
        m.eval()
        fm = torch._C._freeze_module(m._c, ["modify_a"])
        # Both attribute "a" and method "modify_a" are preserved
        self.assertTrue(fm.hasattr("a"))
        self.assertFalse(fm.hasattr("b"))
        input = torch.randn(2, 2)
        expected = m.forward(input)
        out = fm.forward(input)
        self.assertEqual(out, expected)

    def test_freeze_module_with_user_preserved_method2(self):
        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor([1.1])
                self.b = torch.tensor([2.2])

            def forward(self, x):
                self.b += 10
                return self.a + self.b

            @torch.jit.export
            def modify_a(self, x):
                self.a[0] += 10
                return self.b + self.a

        m = torch.jit.script(Module())
        m.eval()
        fm = torch._C._freeze_module(m._c, ["modify_a"])
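        # Both 'a' and 'b' must remain attributes: modify_a mutates 'a' and
        # forward mutates 'b' in place, so neither can be folded to a constant.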
        FileCheck().check('prim::GetAttr[name="a"]').run(fm.forward.graph)
        FileCheck().check('prim::GetAttr[name="b"]').run(fm.modify_a.graph)

    def test_freeze_module_with_user_preserved_attribute_on_submodule(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1
                self.b = 2

            def forward(self):
                return self.a + self.b

        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub1 = SubModule()
                self.sub2 = SubModule()

            def forward(self):
                return self.sub1() + self.sub2()

        m = torch.jit.script(Module())
        m.eval()
        m = torch.jit.freeze(m, preserved_attrs=["sub1.a", "sub2.a"])
        fm = m._c

        self.assertTrue(fm.hasattr("sub1"))
        self.assertTrue(fm.sub1.hasattr("a"))
        self.assertFalse(fm.sub1.hasattr("b"))
        self.assertTrue(fm.hasattr("sub2"))
        self.assertTrue(fm.sub2.hasattr("a"))
        self.assertFalse(fm.sub2.hasattr("b"))
        self.assertEqual(m(), 6)
        m.sub1.a += 1
        self.assertEqual(m(), 7)

    def test_freeze_module_with_user_preserved_attribute_on_unused_submodule(self):
        class SubModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 1
                self.b = 2

            def forward(self):
                return self.a + self.b

            @torch.jit.export
            def method_a(self):
                return 42

        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()

            def forward(self):
                return 1

        m = torch.jit.script(Module())
        m.eval()
        fm = torch.jit.freeze(m, preserved_attrs=["sub.a", "sub.method_a"])._c

        self.assertTrue(fm.hasattr("sub"))
        self.assertTrue(fm.sub.hasattr("a"))
        self.assertFalse(fm.sub.hasattr("b"))
        self.assertTrue(fm.sub._has_method("method_a"))

    def test_freeze_module_with_user_preserved_method_on_submodule(self):
        class SubModule(nn.Module):
            def forward(self, x):
                return self.method_a(x) + self.method_b(x)

            def method_a(self, x):
                return x * x

            def method_b(self, x):
                return x + x

        class Module(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()

            def forward(self, x):
                return self.sub(x)

        m = torch.jit.script(Module())
        m.eval()
        fm = torch.jit.freeze(m, preserved_attrs=["sub.method_a"])._c

        self.assertTrue(fm.hasattr("sub"))
        self.assertTrue(fm.sub._has_method("method_a"))
        self.assertFalse(fm.sub._has_method("method_b"))

    @skipIfNoFBGEMM
    def test_module_with_shared_type_instances(self):
        class Child(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)

            def forward(self, x):
                x = self.conv1(x)
                return x

        class Parent(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.quant = torch.ao.quantization.QuantStub()
                self.conv1 = nn.Conv2d(1, 1, 1).to(dtype=torch.float32)
                self.child = Child()
                self.child2 = Child()
                self.dequant = torch.ao.quantization.DeQuantStub()

            def forward(self, x):
                x = self.quant(x)
                x = self.conv1(x)
                x = self.child(x)
                x = self.child2(x)
                x = self.dequant(x)
                return x

        def _static_quant(model):
            qModel = torch.ao.quantization.QuantWrapper(model)
            qModel.qconfig = torch.ao.quantization.default_qconfig
            torch.ao.quantization.prepare(qModel, inplace=True)
            qModel(torch.rand(4, 1, 4, 4, dtype=torch.float32))
            torch.ao.quantization.convert(qModel, inplace=True)
            return model

        with override_quantized_engine("fbgemm"):
            data = torch.randn(4, 1, 4, 4, dtype=torch.float32)
            m = Parent().to(torch.float32)
            m = _static_quant(m)
            m = torch.jit.script(m)
            m.eval()
            torch._C._jit_pass_inline(m.graph)
            m_frozen = wrap_cpp_module(torch._C._freeze_module(m._c))
            # An earlier bug resulted in _packed_params being set to False.
            FileCheck().check_not("_packed_params = False").run(
                m_frozen._c.dump_to_str(True, True, False)
            )

            m_res = m(data)
            # Running the frozen module used to segfault.
            m_frozen_res = m_frozen(data)
            self.assertEqual(m_res, m_frozen_res)

    def test_module_getattr_indirection(self):
        @torch.jit.script
        class ValHolder:
            def __init__(self, val: int):
                self.val: int = val

        class Mod(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.mod1 = ValHolder(1)
                self.mod2 = ValHolder(2)

            def forward(self, cond: bool):
                if cond:
                    mod = self.mod1
                else:
                    mod = self.mod2
                return mod.val

        mod = Mod()
        mod.eval()
        frozen_mod = torch.jit.freeze(torch.jit.script(mod))
        mod_eager = Mod()
        self.assertEqual(mod_eager(True), frozen_mod(True))
        self.assertEqual(mod_eager(False), frozen_mod(False))

    def test_freeze_module_with_non_static_module_container_index(self):
        """
        Test that Modules containing non-static ModuleDict or ModuleList
        indexing cannot be frozen.
        """

        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                pass

        class ImplementsInterface(torch.nn.Module):
            def forward(self, inp: Any) -> Any:
                if isinstance(inp, torch.Tensor):
                    return torch.max(inp, dim=0)

                return inp

        class ModWithDict(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.d = torch.nn.ModuleDict({"module": ImplementsInterface()})

            def forward(self, x: torch.Tensor, key: str) -> Any:
                value: ModuleInterface = self.d[key]
                return value.forward(x)

        m = torch.jit.script(ModWithDict())
        m.eval()
        with self.assertRaisesRegex(
            RuntimeError,
            "Freezing modules containing prim::ModuleContainerIndex is not supported",
        ):
            mf = torch._C._freeze_module(m._c)

        class ModWithList(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.l = torch.nn.ModuleList([ImplementsInterface()])

            def forward(self, x: torch.Tensor, idx: int) -> Any:
                value: ModuleInterface = self.l[idx]
                return value.forward(x)

        m = torch.jit.script(ModWithList())
        m.eval()
        with self.assertRaisesRegex(
            RuntimeError,
            "Freezing modules containing prim::ModuleContainerIndex is not supported",
        ):
            mf = torch._C._freeze_module(m._c)

    def test_freeze_with_interface_mutable(self):
        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class ImplementsInterface(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sum = torch.zeros((2, 2))

            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                self.sum += inp.relu()
                return self.sum

        class WrapperModule(torch.nn.Module):
            impl: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.impl = ImplementsInterface()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.impl.forward(x)

        m = torch.jit.script(WrapperModule())
        m.eval()
        m_frozen = torch.jit.freeze(m)

        x = torch.rand((2, 2))

        m_frozen(x)
        self.assertEqual(m_frozen.impl.sum, x.relu())

    def test_freeze_with_swapping_interfaces(self):
        @torch.jit.interface
        class ModuleInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class Implementation1(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                return inp.relu()

        class Implementation2(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                return inp.sin()

        class WrapperModule(torch.nn.Module):
            impl: ModuleInterface

            def __init__(self) -> None:
                super().__init__()
                self.option1 = Implementation1()
                self.option2 = Implementation2()
                self.impl = self.option1
                self.idx = 0

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.idx += 1
                if self.idx % 2 == 1:
                    self.impl = self.option1
                else:
                    self.impl = self.option2
                return self.impl(x)

        m = torch.jit.script(WrapperModule())
        m.eval()
        with self.assertRaisesRegex(
            RuntimeError, "Freezing does not support SetAttr on an interface type"
        ):
            m_frozen = torch.jit.freeze(m)

    def test_freeze_recursive_interfaces(self):
        @torch.jit.interface
        class InnerInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        @torch.jit.interface
        class OuterInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class InnerImpl(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.ones((2, 2))

            def forward(self, inp):
                return inp.cos() * self.x

        class OuterImpl(torch.nn.Module):
            inner_impl: InnerInterface

            def __init__(self) -> None:
                super().__init__()
                self.inner_impl = InnerImpl()

            def forward(self, inp):
                return inp.relu() + self.inner_impl(inp.sin())

        class WrapperModule(torch.nn.Module):
            outer_impl: OuterInterface

            def __init__(self) -> None:
                super().__init__()
                self.outer_impl = OuterImpl()

            def forward(self, inp):
                return self.outer_impl(inp) + inp

        m = WrapperModule()
        x = torch.rand((2, 2))
        expected = m(x)

        m_s = torch.jit.script(m)
        m_s.eval()
        m_s = torch.jit.freeze(m_s)
        actual = m_s(x)

        self.assertEqual(expected, actual)

    def test_freeze_recursive_interfaces_with_reassignment(self):
        @torch.jit.interface
        class InnerInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        @torch.jit.interface
        class OuterInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class InnerImpl1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.ones((2, 2))

            def forward(self, inp):
                return inp.cos() * self.x

        class InnerImpl2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.ones((2, 2)) * 2

            def forward(self, inp):
                return inp.sin() / self.x

        class OuterImpl(torch.nn.Module):
            inner_impl: InnerInterface

            def __init__(self) -> None:
                super().__init__()
                self.inner_impl = InnerImpl1()
                self.impl1 = InnerImpl1()
                self.impl2 = InnerImpl2()
                self.idx = 0

            def forward(self, inp):
                self.idx += 1
                if self.idx % 2 == 0:
                    self.inner_impl = self.impl1
                else:
                    self.inner_impl = self.impl2
                return inp.relu() + self.inner_impl(inp.sin())

        class WrapperModule(torch.nn.Module):
            outer_impl: OuterInterface

            def __init__(self) -> None:
                super().__init__()
                self.outer_impl = OuterImpl()

            def forward(self, inp):
                return self.outer_impl(inp) + inp

        m = WrapperModule()

        m_s = torch.jit.script(m)
        m_s.eval()
        with self.assertRaisesRegex(
            RuntimeError, "Freezing does not support SetAttr on an interface type"
        ):
            m_s = torch.jit.freeze(m_s)

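    # The interface SetAttr can live in forward() or in a preserved exported
    # method; freezing should be rejected in both arrangements.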
    def test_freeze_interface_swapping_two_methods(self):
        @torch.jit.interface
        class MyInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class Impl1(torch.nn.Module):
            def forward(self, inp):
                return inp.cos()

        class Impl2(torch.nn.Module):
            def forward(self, inp):
                return inp.sin()

        class WrapperModule1(torch.nn.Module):
            interface_impl: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.interface_impl = Impl1()
                self.impl1 = Impl1()
                self.impl2 = Impl2()
                self.idx = 0

            def forward(self, x):
                return self.interface_impl(x)

            @torch.jit.export
            def other_method(self, x):
                self.idx += 1
                if self.idx % 2 == 0:
                    self.interface_impl = self.impl1
                else:
                    self.interface_impl = self.impl2
                return self.interface_impl(x)

        class WrapperModule2(torch.nn.Module):
            interface_impl: MyInterface

            def __init__(self) -> None:
                super().__init__()
                self.interface_impl = Impl1()
                self.impl1 = Impl1()
                self.impl2 = Impl2()
                self.idx = 0

            def forward(self, x):
                self.idx += 1
                if self.idx % 2 == 0:
                    self.interface_impl = self.impl1
                else:
                    self.interface_impl = self.impl2
                return self.interface_impl(x)

            @torch.jit.export
            def other_method(self, x):
                return self.interface_impl(x)

        m1 = torch.jit.script(WrapperModule1())
        m2 = torch.jit.script(WrapperModule2())

        m1.eval()
        m2.eval()

        with self.assertRaisesRegex(
            RuntimeError, "Freezing does not support SetAttr on an interface type"
        ):
            torch.jit.freeze(m1, preserved_attrs=["other_method"])

        with self.assertRaisesRegex(
            RuntimeError, "Freezing does not support SetAttr on an interface type"
        ):
            torch.jit.freeze(m2, preserved_attrs=["other_method"])

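    # The wrapper and the outer implementation both name their interface
    # attribute `impl`; freezing must keep the two attributes distinct.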
    def test_freeze_recursive_interfaces_same_name(self):
        @torch.jit.interface
        class InnerInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        @torch.jit.interface
        class OuterInterface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class InnerImpl(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.ones((2, 2))

            def forward(self, inp):
                return inp.cos() * self.x

        class OuterImpl(torch.nn.Module):
            impl: InnerInterface

            def __init__(self) -> None:
                super().__init__()
                self.impl = InnerImpl()
                self.x = torch.ones((2, 2)) * 5

            def forward(self, inp):
                return self.other_method(inp)

            def other_method(self, inp):
                return inp.relu() + self.impl(inp.sin()) + self.x

        class WrapperModule(torch.nn.Module):
            impl: OuterInterface

            def __init__(self) -> None:
                super().__init__()
                self.impl = OuterImpl()

            def forward(self, inp):
                return self.impl(inp) + inp

        m = WrapperModule()
        x = torch.rand((2, 2))
        expected = m(x)

        m_s = torch.jit.script(m)
        m_s.eval()
        m_s = torch.jit.freeze(m_s)
        actual = m_s(x)

        self.assertEqual(expected, actual)

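    # Unlike the interface cases above, swapping between concrete
    # (non-interface) submodule attributes is supported by freezing.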
    def test_freeze_non_interface_module_swap(self):
        class InnerModule(torch.nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                return inp.relu() + self.x

        class WrapperModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.option1 = InnerModule(torch.rand((2, 2)))
                self.option2 = InnerModule(torch.rand((2, 2)))
                self.impl = self.option1
                self.idx = 0

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.idx += 1
                if self.idx % 2 == 1:
                    self.impl = self.option1
                else:
                    self.impl = self.option2
                return self.impl(x)

        unfrozen = WrapperModule()
        m = torch.jit.script(unfrozen)
        m.eval()
        m_frozen = torch.jit.freeze(m)

        x = torch.rand((2, 2))
        expected = unfrozen(x)
        actual = m_frozen(x)
        self.assertEqual(expected, actual)

    @unittest.expectedFailure
    def test_freeze_interface_within_object(self):
        # There doesn't appear to be any way to create a plain Python object
        # that holds a module implementing an interface, but just in case...
        # It's unclear whether freezing would handle this case correctly, so
        # this is marked as an expected failure: if it ever _does_ start
        # working, someone will need to verify it is handled correctly.
        @torch.jit.interface
        class MyIface(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                pass

        class MyImpl(torch.nn.Module):
            def forward(self, inp: torch.Tensor) -> torch.Tensor:
                return inp.sin()

        class MyObject:
            impl: MyIface

            def run(self, x):
                return self.impl(x)

        class WrapperModule(torch.nn.Module):
            impl: MyObject

            def __init__(self) -> None:
                super().__init__()
                self.impl = MyObject()
                self.impl.impl = MyImpl()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.impl.run(x)

        unfrozen = WrapperModule()
        m = torch.jit.script(unfrozen)
        m.eval()
        m_frozen = torch.jit.freeze(m)

        x = torch.rand((2, 2))
        expected = unfrozen(x)
        actual = m_frozen(x)
        self.assertEqual(expected, actual)

    def test_freeze_non_module_class_getattr(self):
        class BoxCoder:
            def __init__(self, bbox_xform_clip):
                # type: (float) -> None
                self.bbox_xform_clip = bbox_xform_clip

            def decode(self, input):
                return input * self.bbox_xform_clip

        class MyModule(torch.nn.Module):
            __annotations__ = {
                "box_coder": BoxCoder,
            }

            def __init__(self) -> None:
                super().__init__()
                self.box_coder = BoxCoder(50.0)

            def forward(self, input):
                return self.box_coder.decode(input)

        model = MyModule()
        model.eval()
        script_model = torch.jit.freeze(torch.jit.script(model))
        inp = torch.randn([4, 4])
        output_eager = model(inp)
        self.assertEqual(output_eager, script_model(inp))
        FileCheck().check_not("GetAttr").run(script_model.graph)

    def test_freeze_module_with_tupleoutput_submodule(self):
        class SubModule(nn.Module):
            def forward(self, x):
                return (x + 1, x + 2)

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sub = SubModule()

            def forward(self, x):
                y1, y2 = self.sub(x)
                return y1 + y2

        m = torch.jit.script(TestModule())
        m = m.eval()
        mf = torch.jit.freeze(m)
        inp = torch.randn(2, 2)
        expected = m.forward(inp)
        output = mf.forward(inp)
        # Check that prim::TupleConstruct and prim::TupleUnpack
        # don't exist in the frozen graph
        FileCheck().check_not("prim::TupleConstruct").run(mf.graph)
        FileCheck().check_not("prim::TupleUnpack").run(mf.graph)
        self.assertEqual(output, expected)

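    # preserved_attrs keeps make_prediction callable on the frozen module;
    # forward is preserved by default.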
    def test_freeze_module_with_call_method(self):
        class Mod(nn.Module):
            def __init__(self, val):
                super().__init__()
                self.param = nn.Parameter(val)

            def forward(self, x):
                # this method will change during freezing
                return x + self.param

            @torch.jit.export
            def make_prediction(self, x):
                y = x + x
                return self.forward(y)

        param = torch.rand([2, 2])
        x = torch.rand([2, 2])

        unscripted_mod = Mod(param)
        mod = torch.jit.script(unscripted_mod)
        mod.eval()
        mod = torch.jit.freeze(mod, preserved_attrs=["make_prediction"])

        self.assertEqual(
            mod.forward(x), unscripted_mod.forward(x), atol=1e-5, rtol=1e-5
        )


@skipIfTorchDynamo("somehow causing hanging during python shutdown")
class TestFrozenOptimizations(JitTestCase):
    def setUp(self):
        super().setUp()
        self.default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.double)

    def tearDown(self):
        torch.set_default_dtype(self.default_dtype)
        super().tearDown()

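    # Background (standard conv-bn folding, sketch): with frozen BatchNorm
    # statistics (running mean mu, running var sigma^2, affine gamma/beta,
    # eps), the batch norm can be absorbed into the preceding conv as
    #   scale = gamma / sqrt(sigma^2 + eps)
    #   w_folded = w * scale    (broadcast over output channels)
    #   b_folded = (b - mu) * scale + beta
    # This needs running stats, which is why folding is expected to no-op
    # when track_running_stats=False below. The repeated assertEqual calls
    # at the end exercise the second, profiled run of the frozen graph.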
    def test_conv_bn_folding(self):
        conv_bias = [True, False]
        module_pairs = [
            (nn.Conv1d, nn.BatchNorm1d),
            (nn.Conv2d, nn.BatchNorm2d),
            (nn.Conv3d, nn.BatchNorm3d),
        ]
        use_tracing = [True, False]
        bn_running_stats = [True, False]

        for use_bias, modules, tracing, track_stats in product(
            conv_bias, module_pairs, use_tracing, bn_running_stats
        ):

            class ConvBN(torch.nn.Module):
                def __init__(self, in_channels, out_channels, **kwargs):
                    super().__init__()
                    self.conv = modules[0](
                        in_channels, out_channels, bias=use_bias, **kwargs
                    )
                    self.bn = modules[1](
                        out_channels, eps=0.001, track_running_stats=track_stats
                    )

                def forward(self, x):
                    x = self.conv(x)
                    return self.bn(x)

            mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval()
            inps = [4, 3, 4]
            if modules[0] == nn.Conv2d:
                inps.append(inps[-1])
            if modules[0] == nn.Conv3d:
                inps.append(inps[-1])
                inps.append(inps[-1])

            inp = torch.rand(inps)

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
            else:
                scripted_mod = torch.jit.script(mod_eager)

            self.run_pass("inline", scripted_mod.graph)
            self.run_pass("peephole", scripted_mod.graph)
            self.run_pass("constant_propagation", scripted_mod.graph)

            FileCheck().check("conv").check("batch").run(scripted_mod.graph)
            # successfully no-ops with non-const inputs
            self.run_pass("fold_frozen_conv_bn", scripted_mod.graph)
            FileCheck().check("conv").check("aten::batch_norm").run(scripted_mod.graph)

            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("fold_frozen_conv_bn", scripted_mod.graph)
            if track_stats:
                FileCheck().check("conv").check_not("aten::batch_norm").run(
                    scripted_mod.graph
                )
            else:
                FileCheck().check("conv").check("aten::batch_norm").run(
                    scripted_mod.graph
                )

            self.assertEqual(mod_eager(inp), scripted_mod(inp))
            self.assertEqual(mod_eager(inp), scripted_mod(inp))

    def test_conv_bn_folding_not_forward(self):
        class ConvBN(torch.nn.Module):
            def __init__(self, in_channels, out_channels, **kwargs):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, bias=True, **kwargs
                )
                self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
                self.amt = 3.2

            def forward(self, x):
                x = self.conv(x)
                return self.bn(x)

            @torch.jit.export
            def make_prediction(self, x):
                return self.forward(x) + self.amt

        mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).eval()
        scripted_mod = torch.jit.script(mod_eager)
        torch._C._jit_pass_inline(scripted_mod.make_prediction.graph)
        FileCheck().check("conv").check("aten::batch_norm").run(
            scripted_mod.make_prediction.graph
        )

        # _jit_pass_optimize_frozen_graph should not be called on non-method attributes (e.g. "amt")
        scripted_mod = torch.jit.freeze(
            scripted_mod, preserved_attrs=["make_prediction", "amt"]
        )
        FileCheck().check("conv").check_not("aten::batch_norm").run(
            scripted_mod.make_prediction.graph
        )

    # During freezing this creates tensors constants that are attached to the frozen graph,
    # which is then kept alive by the compilation unit (which causes a leak)
    @skipCUDAMemoryLeakCheckIf(True)
    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_conv_bn_folding_autocast_scenario_cuda(self):
        # CUDA conv takes input tensors that must all have the same dtype,
        # which can cause issues if folding produces inputs of different dtypes.

        class ConvBN(torch.nn.Module):
            def __init__(self, in_channels, out_channels, **kwargs):
                super().__init__()
                self.conv = torch.nn.Conv2d(
                    in_channels, out_channels, bias=False, dtype=torch.half, **kwargs
                )
                self.bn = torch.nn.BatchNorm2d(
                    out_channels, eps=0.001, dtype=torch.float
                )

            def forward(self, x):
                return self.bn(self.conv(x))

        mod_eager = ConvBN(3, 32, kernel_size=3, stride=2).cuda().eval()
        scripted_mod = torch.jit.script(mod_eager)
        scripted_mod = torch.jit.freeze(scripted_mod)
        FileCheck().check("conv").check_not("aten::batch_norm").run(scripted_mod.graph)
        conv_node = scripted_mod.graph.findNode("aten::conv2d", True)
        self.assertTrue(conv_node is not None)
        bias_input = conv_node.namedInput("bias")
        self.assertTrue(bias_input is not None)
        self.assertTrue(bias_input.type().dtype() == torch.half)

        x = torch.rand((3, 3, 32, 32), dtype=torch.half).cuda()

        self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
        self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)

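    # Elementwise mul/div by a constant can be folded into the conv weight
    # and bias, and add/sub into the bias alone; folding is only valid when
    # the constant is scalar-like or broadcasts per output channel, as the
    # expect_success flags below encode.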
    def test_conv_add_folding(self):
        @torch.no_grad()
        def test_conv_fusion(
            use_bias, module, tracing, op, scalar, add_tensor, expect_success
        ):
            class ConvOp(torch.nn.Module):
                __constants__ = ["use_scalar"]

                def __init__(self, in_channels, out_channels, tensor=None, **kwargs):
                    super().__init__()
                    self.conv = module(
                        in_channels, out_channels, bias=use_bias, **kwargs
                    )
                    self.conv2 = module(
                        in_channels, out_channels, bias=use_bias, **kwargs
                    )
                    self.use_scalar = scalar
                    tensor_size = [1 for _ in range(self.conv.weight.ndim)]
                    tensor_size[1] = self.conv.weight.size(0)
                    self.tensor = (
                        add_tensor
                        if add_tensor is not None
                        else torch.rand(tensor_size)
                    )
                    self.op = op

                def forward(self, x):
                    x = self.conv(x)
                    if self.use_scalar:
                        return self.op(x, 2.0)
                    else:
                        return self.op(x, self.tensor)

            mod_eager = ConvOp(3, 32, kernel_size=3, stride=2).eval()

            inps = [4, 3, 4]
            if module == nn.Conv2d:
                inps.append(inps[-1])
            if module == nn.Conv3d:
                inps.append(inps[-1])
                inps.append(inps[-1])

            inp = torch.rand(inps)

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
            else:
                scripted_mod = torch.jit.script(mod_eager)

            self.run_pass("inline", scripted_mod.graph)
            op_str = "aten::" + op.__name__

            FileCheck().check("conv").check(op_str).run(scripted_mod.graph)
            # successfully no-ops with non-const inputs
            self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph)
            self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph)
            FileCheck().check("conv").check(op_str).run(scripted_mod.graph)
            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("fold_frozen_conv_mul_or_div", scripted_mod.graph)
            self.run_pass("fold_frozen_conv_add_or_sub", scripted_mod.graph)

            if expect_success:
                FileCheck().check("conv").check_not(op_str).run(scripted_mod.graph)
            else:
                FileCheck().check("conv").check(op_str).run(scripted_mod.graph)

            self.assertEqual(mod_eager(inp), scripted_mod(inp))
            self.assertEqual(mod_eager(inp), scripted_mod(inp))

        conv_bias = [True, False]
        modules = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
        use_tracing = [False, True]
        use_scalar = [False, True]
        ops = [torch.add, torch.sub, torch.mul, torch.div]

        for use_bias, module, tracing, pytorch_op, scalar in product(
            conv_bias, modules, use_tracing, ops, use_scalar
        ):
            test_conv_fusion(
                use_bias,
                module,
                tracing,
                pytorch_op,
                scalar,
                add_tensor=None,
                expect_success=True,
            )

        for use_bias, pytorch_op in product(conv_bias, ops):
            # broadcasting tensor: changes the output shape, folding should be skipped
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                False,
                pytorch_op,
                False,
                add_tensor=torch.rand(32, 1, 32),
                expect_success=False,
            )

            # scalar-like broadcasting tensor: folding should succeed
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                False,
                pytorch_op,
                False,
                add_tensor=torch.rand(1, 1),
                expect_success=True,
            )

            # add with different dtype
            test_conv_fusion(
                use_bias,
                nn.Conv2d,
                False,
                pytorch_op,
                False,
                add_tensor=torch.tensor([2]).to(torch.int),
                expect_success=True,
            )

    def test_conv_mul_add_bn(self):
        class Conv_Mul_Add_Bn(nn.Module):
            def __init__(self, in_channels, out_channels, **kwargs):
                super().__init__()
                self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
                self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
                self.tensor1 = torch.tensor(2.2)
                self.tensor2 = torch.tensor(2)

            def forward(self, x):
                return self.bn(
                    torch.add(torch.mul(self.conv(x), self.tensor1), self.tensor2)
                )

        input = torch.randn(8, 3, 64, 64)
        model = Conv_Mul_Add_Bn(3, 32, kernel_size=3, stride=1).eval()

        with torch.no_grad():
            result = model(input)
            traced_model = torch.jit.trace(model, input).eval()
            traced_model = torch.jit.freeze(traced_model)
            tresult = traced_model(input)
            self.assertEqual(result, tresult)
            FileCheck().check("conv").check_not("aten::batch_norm").run(
                traced_model.graph
            )
            FileCheck().check("conv").check_not("aten::add").run(traced_model.graph)

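    # Linear-bn folding uses the same algebra as conv-bn folding above, with
    # the scale applied along the output features of the linear weight.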
    def test_linear_bn_folding(self):
        module_pairs = [
            (nn.Linear, nn.BatchNorm1d),
            (nn.Linear, nn.BatchNorm2d),
            (nn.Linear, nn.BatchNorm3d),
        ]
        use_tracing = [True, False]
        bn_running_stats = [True, False]

        for modules, tracing, track_stats in product(
            module_pairs, use_tracing, bn_running_stats
        ):

            class LinearBN(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super().__init__()
                    self.linear = modules[0](in_features, out_features)
                    self.bn = modules[1](
                        out_features, eps=0.001, track_running_stats=track_stats
                    )

                def forward(self, x):
                    x = self.linear(x)
                    return self.bn(x)

            mod_eager = LinearBN(32, 32).eval()

            inps = [3, 32]
            if modules[1] == nn.BatchNorm2d:
                inps.append(inps[-1])
                inps.append(inps[-1])
            if modules[1] == nn.BatchNorm3d:
                inps.append(inps[-1])
                inps.append(inps[-1])
                inps.append(inps[-1])

            inp = torch.rand(inps)

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
            else:
                scripted_mod = torch.jit.script(mod_eager)

            self.run_pass("inline", scripted_mod.graph)
            self.run_pass("peephole", scripted_mod.graph)
            self.run_pass("constant_propagation", scripted_mod.graph)

            FileCheck().check("linear").check("batch").run(scripted_mod.graph)
            # successfully no-ops with non-const inputs
            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
            FileCheck().check("linear").check("aten::batch_norm").run(
                scripted_mod.graph
            )

            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
            if track_stats:
                FileCheck().check("linear").check_not("aten::batch_norm").run(
                    scripted_mod.graph
                )
            else:
                FileCheck().check("linear").check("aten::batch_norm").run(
                    scripted_mod.graph
                )

            self.assertEqual(mod_eager(inp), scripted_mod(inp))
            self.assertEqual(mod_eager(inp), scripted_mod(inp))

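    # If bn.num_features matches neither linear.out_features nor 1, the batch
    # norm normalizes a different dimension than the linear output features,
    # so folding would change semantics and must be skipped.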
    def test_bn_not_broadcast_with_linear(self):
        module_pairs = [
            (nn.Linear, nn.BatchNorm1d),
            (nn.Linear, nn.BatchNorm2d),
            (nn.Linear, nn.BatchNorm3d),
        ]
        use_tracing = [True, False]
        linear_in = 3
        # (linear_out, bn_in)
        # case 1: linear_out < bn_in
        # case 2: linear_out > bn_in
        # case 3: linear_out != bn_in and linear_out == 1
        dims = [(2, 4), (4, 2), (1, 2)]

        for modules, tracing, dim in product(module_pairs, use_tracing, dims):
            linear_out, bn_in = dim[0], dim[1]

            linear = modules[0](linear_in, linear_out)
            bn = modules[1](bn_in)
            mod_eager = nn.Sequential(linear, bn).eval()

            N, C = 3, bn_in
            input_shape = [N, C]
            if modules[1] == nn.BatchNorm1d:
                H = linear_in
                input_shape.append(H)
            elif modules[1] == nn.BatchNorm2d:
                H, W = 4, linear_in
                input_shape.append(H)
                input_shape.append(W)
            elif modules[1] == nn.BatchNorm3d:
                D, H, W = 4, 4, linear_in
                input_shape.append(D)
                input_shape.append(H)
                input_shape.append(W)

            inp = torch.rand(input_shape)

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (inp,))
            else:
                scripted_mod = torch.jit.script(mod_eager)

            self.run_pass("inline", scripted_mod.graph)
            self.run_pass("peephole", scripted_mod.graph)
            self.run_pass("constant_propagation", scripted_mod.graph)

            FileCheck().check("linear").check("batch").run(scripted_mod.graph)
            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
            FileCheck().check("linear").check("aten::batch_norm").run(
                scripted_mod.graph
            )

            frozen_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("fold_frozen_linear_bn", frozen_mod.graph)
            # successfully skipped folding
            FileCheck().check("linear").check("aten::batch_norm").run(frozen_mod.graph)

            self.assertEqual(mod_eager(inp), frozen_mod(inp))
            self.assertEqual(mod_eager(inp), frozen_mod(inp))

            # the eager fusion helper should also refuse to fuse
            with self.assertRaisesRegex(
                AssertionError,
                "To fuse, linear.out_features == bn.num_features or bn.num_features == 1",
            ):
                nn.utils.fusion.fuse_linear_bn_eval(linear, bn)

    @skipCUDAMemoryLeakCheckIf(True)
    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_bn_folding_autocast_scenario_cuda(self):
        module_pairs = [
            (nn.Linear, nn.BatchNorm1d),
            (nn.Linear, nn.BatchNorm2d),
            (nn.Linear, nn.BatchNorm3d),
        ]
        use_tracing = [True, False]
        bn_running_stats = [True, False]

        for modules, tracing, track_stats in product(
            module_pairs, use_tracing, bn_running_stats
        ):

            class LinearBN(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super().__init__()
                    self.linear = modules[0](
                        in_features, out_features, bias=False, dtype=torch.half
                    )
                    self.bn = modules[1](out_features, eps=0.001, dtype=torch.float)

                def forward(self, x):
                    x = self.linear(x)
                    return self.bn(x)

            mod_eager = LinearBN(32, 32).cuda().eval()

            inps = [3, 32]
            if modules[1] == nn.BatchNorm2d:
                inps.append(inps[-1])
                inps.append(inps[-1])
            if modules[1] == nn.BatchNorm3d:
                inps.append(inps[-1])
                inps.append(inps[-1])
                inps.append(inps[-1])

            x = torch.rand(inps, dtype=torch.half).cuda()

            if tracing:
                scripted_mod = torch.jit.trace(mod_eager, (x,))
            else:
                scripted_mod = torch.jit.script(mod_eager)
            scripted_mod = torch.jit.freeze(scripted_mod)
            FileCheck().check("linear").check_not("aten::batch_norm").run(
                scripted_mod.graph
            )
            lin_node = scripted_mod.graph.findNode("aten::linear", True)
            self.assertTrue(lin_node is not None)
            weight_input = lin_node.namedInput("weight")
            bias_input = lin_node.namedInput("bias")
            self.assertTrue(bias_input is not None)
            self.assertTrue(weight_input.type().dtype() == torch.half)
            self.assertTrue(bias_input.type().dtype() == torch.half)

            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)

    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_concat(self):
        out_dimms = [[5, 10], [1, 5]]

        for w1_dim, w2_dim in out_dimms:

            class ModMultLinear(nn.Module):
                def __init__(self, w1_dim, w2_dim):
                    super().__init__()
                    self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                    self.b1 = nn.Parameter(torch.rand([w1_dim]))
                    self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                    self.b2 = nn.Parameter(torch.rand([w2_dim]))

                def forward(self, in_tensor1):
                    res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                    res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2)
                    return res1, res2

            mod_eager = ModMultLinear(w1_dim, w2_dim).eval()

            test_val1 = torch.rand([50, 5])
            self.check_linear_optimizations(mod_eager, 2, 1, (test_val1,))

    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_concat_complex(self):
        """
        Test that interleaving multiple optimizations does not
        cause errors and that the graph gets optimized as expected.
        """

        class ModMultLinear(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                w1_dim = 5
                w2_dim = 10
                self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                self.b1 = nn.Parameter(torch.rand([w1_dim]))
                self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                self.b2 = nn.Parameter(torch.rand([w2_dim]))

            def forward(self, in_tensor1):
                res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                res3 = torch._C._nn.linear(res1, self.w2, self.b2)
                res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b2)
                res4 = torch._C._nn.linear(res1, self.w1, self.b1)
                return res2, res3, res4

        mod_eager = ModMultLinear().eval()
        test_val1 = torch.rand([50, 5])
        self.check_linear_optimizations(mod_eager, 4, 2, (test_val1,))

    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_concat_different_input(self):
        """
        The optimization pass should leave the graph unchanged
        because the two linear ops consume different input tensors.
        """

        # Freezing requires that the graph be a module
        class ModMultLinear(nn.Module):
            def __init__(self, w1_dim, w2_dim):
                super().__init__()
                self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                self.b1 = nn.Parameter(torch.rand([w1_dim]))
                self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                self.b2 = nn.Parameter(torch.rand([w2_dim]))

            def forward(self, in_tensor1, in_tensor2):
                res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                res2 = torch._C._nn.linear(in_tensor2, self.w2, self.b2)
                return res1, res2

        mod_eager = ModMultLinear(5, 5).eval()
        test_val1 = torch.rand([50, 5])
        test_val2 = torch.rand([50, 5])
        self.check_linear_optimizations(mod_eager, 2, 2, (test_val1, test_val2))

    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
    def test_linear_multiple_blocks(self):
        class ModMultLinear(nn.Module):
            def __init__(self, w1_dim, w2_dim):
                super().__init__()
                self.w1 = nn.Parameter(torch.rand([w1_dim, 5]))
                self.b1 = nn.Parameter(torch.rand([w1_dim]))
                self.w2 = nn.Parameter(torch.rand([w2_dim, 5]))
                self.b2 = nn.Parameter(torch.rand([w2_dim]))

            def forward(self, in_tensor1, in_tensor2, cond: bool):
                res1 = torch._C._nn.linear(in_tensor1, self.w1, self.b1)
                if cond:
                    res3 = torch._C._nn.linear(in_tensor2, self.w2, self.b2)
                    res4 = torch._C._nn.linear(in_tensor1, self.w2, self.b1)
                else:
                    raise AssertionError
                res2 = torch._C._nn.linear(in_tensor1, self.w2, self.b1)
                return res1, res2, res3, res4

        mod_eager = ModMultLinear(5, 5).eval()
        test_val1 = torch.rand([50, 5])
        test_val2 = torch.rand([50, 5])
        self.check_linear_optimizations(mod_eager, 4, 3, (test_val1, test_val2, True))

    def check_linear_optimizations(
        self, eager_mod, orig_linears, new_linears, test_vals
    ):
        for is_cuda in [False, True]:
            if is_cuda:
                mod_to_device = eager_mod.cuda()
                test_vals_to_device = [
                    t.cuda() if isinstance(t, torch.Tensor) else t for t in test_vals
                ]
            else:
                mod_to_device = eager_mod
                test_vals_to_device = test_vals

            script_mod = torch.jit.script(mod_to_device)
            op_graph = script_mod.graph

            FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
                op_graph
            )
            # successfully no-ops with non-const inputs
            self.run_pass("concat_frozen_linear", op_graph)
            FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
                op_graph
            )

            script_mod = torch.jit.freeze(script_mod)
            op_graph = script_mod.graph
            self.run_pass("concat_frozen_linear", op_graph)
            if is_cuda:
                FileCheck().check_count("aten::linear", new_linears, exactly=True).run(
                    op_graph
                )
            else:
                FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
                    op_graph
                )

            self.assertEqual(
                mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device)
            )

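    # torch.jit.freeze runs the numerics-changing frozen optimizations by
    # default; optimize_numerics=False defers them so they can be applied
    # later with torch.jit.run_frozen_optimizations.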
    def test_optimize_freeze_module(self):
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=2, bias=True
        )
        bn = torch.nn.BatchNorm2d(out_channels, eps=0.001)
        mod = torch.nn.Sequential(conv, bn)
        # set optimize_numerics=False here; by default, freezing runs run_frozen_optimizations
        frozen_mod = torch.jit.freeze(
            torch.jit.script(mod.eval()), optimize_numerics=False
        )
        # inspect frozen mod
        FileCheck().check("batch_norm").run(frozen_mod.graph)
        torch.jit.run_frozen_optimizations(frozen_mod)
        FileCheck().check_not("batch_norm").run(frozen_mod.graph)

        # with the default arguments, run_frozen_optimizations runs as part of freezing
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
        FileCheck().check_not("batch_norm").run(frozen_mod.graph)

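    # Dropout is the identity in eval mode, so freezing removes it from the
    # graph entirely.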
    def test_freeze_remove_dropout(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = nn.Dropout(0.5)

            def forward(self, x):
                return self.dropout(x)

        mod = torch.jit.script(Net())
        # inspect mod
        torch._C._jit_pass_inline(mod.graph)
        FileCheck().check("aten::dropout").run(mod.graph)
        frozen_mod = torch.jit.freeze(mod.eval())
        FileCheck().check_not("aten::dropout").run(frozen_mod.graph)

        input = torch.randn(2)
        output_s = mod.forward(input)
        output_f = frozen_mod.forward(input)
        self.assertEqual(output_s, output_f)

    def test_freeze_remove_feature_dropout(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dropout = nn.Dropout2d(0.5)

            def forward(self, x):
                return self.dropout(x)

        mod = torch.jit.script(Net().eval())
        # inspect mod
        torch._C._jit_pass_inline(mod.graph)
        FileCheck().check("aten::feature_dropout").run(mod.graph)
        frozen_mod = torch.jit.freeze(mod)
        FileCheck().check_not("aten::feature_dropout").run(frozen_mod.graph)

        input = torch.randn(2, 2, 1, 1)
        output_s = mod.forward(input)
        output_f = frozen_mod.forward(input)
        self.assertEqual(output_s, output_f)

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_freeze_mkldnn(self):
        conv = torch.nn.Conv2d(3, 32, kernel_size=3, stride=2).eval().float()
        convmkl = mkldnn_utils.to_mkldnn(conv)
        out = torch.jit.freeze(torch.jit.script(convmkl.eval()))
        inp = torch.rand([4, 3, 4, 4]).float()
        self.assertEqual(out(inp.to_mkldnn()).to_dense(), conv(inp))

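    # convert_frozen_ops_to_mkldnn should wrap the frozen conv in
    # to_mkldnn/to_dense layout conversions around an MKLDNN convolution.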
    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_conv_to_mkldnn(self):
        with set_default_dtype(torch.float):
            for module, trace in product([nn.Conv2d, nn.Conv3d], [False, True]):
                mod = module(3, 32, kernel_size=3, stride=2).eval()
                inps = [4, 3, 4]
                if module == nn.Conv2d:
                    inps.append(inps[-1])
                if module == nn.Conv3d:
                    inps.append(inps[-1])
                    inps.append(inps[-1])

                inp = torch.rand(inps)
                if trace:
                    scripted_mod = torch.jit.trace(mod, (inp,))
                else:
                    scripted_mod = torch.jit.script(mod)

                self.run_pass("inline", scripted_mod.graph)

                FileCheck().check("conv").run(scripted_mod.graph)
                # successfully no-ops with non-const inputs
                self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
                FileCheck().check_not("to_mkldnn").run(scripted_mod.graph)

                scripted_mod = torch.jit.freeze(scripted_mod)
                self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
                FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check(
                    "to_dense"
                ).run(scripted_mod.graph)

                self.assertEqual(mod(inp), scripted_mod(inp))
                self.assertEqual(mod(inp), scripted_mod(inp))

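    # transpose_frozen_linear eliminates aten::linear only when the weight is
    # a frozen constant (new_linears=0 below); with a non-constant weight the
    # pass leaves the op in place.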
    def test_linear_transpose(self):
        class ModLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.rand(30))
                self.weight = torch.nn.Parameter(torch.rand([30, 20]))

            def forward(self, x):
                return torch._C._nn.linear(x, self.weight, self.bias)

        mod_eager = ModLinear().eval()
        test_val = torch.rand([50, 20])
        self.check_linear_optimizations_2(
            mod_eager, 1, 0, "transpose_frozen_linear", (test_val,)
        )

    def test_linear_non_constant_weight(self):
        class ModLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bias = torch.nn.Parameter(torch.rand(30))

            def forward(self, x, weight):
                return torch._C._nn.linear(x, weight, self.bias)

        mod_eager = ModLinear().eval()
        test_val = torch.rand([50, 20])
        test_weight = torch.rand([30, 20])
        self.check_linear_optimizations_2(
            mod_eager, 1, 1, "transpose_frozen_linear", (test_val, test_weight)
        )

    def check_linear_optimizations_2(
        self, eager_mod, orig_linears, new_linears, opt_pass, test_vals
    ):
        # TODO: merge with check_linear_optimizations once both diffs land
        mod_to_device = eager_mod
        test_vals_to_device = test_vals

        script_mod = torch.jit.script(mod_to_device)
        op_graph = script_mod.graph

        FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
            op_graph
        )
        # successfully no-ops with non-const inputs
        self.run_pass(opt_pass, op_graph)
        FileCheck().check_count("aten::linear", orig_linears, exactly=True).run(
            op_graph
        )

        script_mod = torch.jit.freeze(script_mod)
        op_graph = script_mod.graph
        self.run_pass(opt_pass, op_graph)
        FileCheck().check_count("aten::linear", new_linears, exactly=True).run(op_graph)

        self.assertEqual(
            mod_to_device(*test_vals_to_device), script_mod(*test_vals_to_device)
        )

    @staticmethod
    def conv():
        # Generic composable conv for testing purposes
        return nn.Conv2d(8, 8, 1)

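    # Adjacent to_dense/to_mkldnn conversion pairs between the two convs
    # should be collapsed, leaving a single conversion into MKLDNN layout.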
    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_collapse_adjacent_conversions(self):
        with set_default_dtype(torch.float):
            mod = nn.Sequential(self.conv(), self.conv()).eval()
            scripted_mod = torch.jit.script(mod)
            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
            FileCheck().check("to_mkldnn").check("prim::mkldnn_convolution").check(
                "prim::mkldnn_convolution"
            ).check("to_dense").run(scripted_mod.graph)
            FileCheck().check_count("to_mkldnn", 1, exactly=True).run(
                scripted_mod.graph
            )

            inp = torch.rand([1, 8, 8, 8])
            self.assertEqual(scripted_mod(inp), mod(inp))
            self.assertEqual(scripted_mod(inp), mod(inp))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_mkldnn_fuser_broadcasting(self):
        class Add(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                return x + self.tensor

        with set_default_dtype(torch.float):
            for add_inp in ([8], [8, 8, 1]):
                mod = nn.Sequential(self.conv(), Add(torch.rand(add_inp))).eval()
                scripted_mod = torch.jit.script(mod)
                scripted_mod = torch.jit.freeze(scripted_mod)
                self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
                FileCheck().check("prim::BroadcastMKLDNNTensors").run(
                    scripted_mod.graph
                )
                inp = torch.rand([1, 8, 8, 8])
                self.assertEqual(scripted_mod(inp), mod(inp))
                self.assertEqual(scripted_mod(inp), mod(inp))

                # for good measure, check that broadcasting does not work without this op
                # so we can remove the op if it ever gets supported
                with self.assertRaisesRegex(RuntimeError, ""):
                    (
                        torch.rand([1, 8, 8, 8]).to_mkldnn()
                        + torch.rand(add_inp).to_mkldnn()
                    )

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_mkldnn_inplace_removal(self):
        class AddDiv(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                return x.add_(self.tensor).div_(self.tensor) - 4

        with set_default_dtype(torch.float):
            mod = nn.Sequential(self.conv(), AddDiv(torch.rand([8]))).eval()
            scripted_mod = torch.jit.script(mod)
            scripted_mod = torch.jit.freeze(scripted_mod)
            self.run_pass("convert_frozen_ops_to_mkldnn", scripted_mod.graph)
            # add gets uninplaced and reinplaced
            FileCheck().check("aten::to_mkldnn").check("aten::add_").check(
                "aten::div_"
            ).run(scripted_mod.graph)
            inp = torch.rand([1, 8, 8, 8])
            self.assertEqual(scripted_mod(inp), mod(inp))
            self.assertEqual(scripted_mod(inp), mod(inp))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    @skipIfNoTorchVision
    def test_maxpool_mkldnn(self):
        with set_default_dtype(torch.float):
            model = torchvision.models.resnet18()
            sub_model = torch.nn.Sequential(
                model.conv1, model.bn1, model.relu, model.maxpool
            )
            mod = torch.jit.freeze(torch.jit.script(sub_model.eval()))
            N, C, H, W = 10, 3, 224, 224
            inp = torch.randn(N, C, H, W)
            self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
            FileCheck().check("max_pool").check("to_dense").run(mod.graph)
            FileCheck().check_count("to_dense", 1, exactly=True).run(mod.graph)
            self.assertEqual(mod(inp), sub_model(inp))

    @unittest.skipIf(torch.backends.mkldnn.is_available(), "Testing no mkldnn")
    def test_conv_to_mkldnn_no_mkldnn(self):
        # test no error when mkldnn not available
        with set_default_dtype(torch.float):
            mod = torch.jit.script(nn.Conv2d(3, 32, kernel_size=3, stride=2).eval())
            frozen = torch.jit.freeze(mod)
            self.run_pass("convert_frozen_ops_to_mkldnn", frozen.graph)
            inp = torch.rand([4, 3, 4, 4])
            self.assertEqual(frozen(inp), mod(inp))

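    # optimize_for_inference should rewrite conv (+ optional residual add) +
    # relu into the fused cudnn/miopen kernels checked below.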
    @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN")
    def test_freeze_conv_relu_fusion(self):
        with set_default_dtype(torch.float):
            conv_bias = [True, False]
            conv_ops = [nn.Conv2d, nn.Conv3d]
            use_add_z = [True, False]
            use_tracing = [True, False]
            for use_bias, conv, add_z, tracing in product(
                conv_bias, conv_ops, use_add_z, use_tracing
            ):

                class Net(nn.Module):
                    def __init__(self, in_channels, out_channels, **kwargs):
                        super().__init__()
                        self.conv = conv(
                            in_channels, out_channels, bias=use_bias, **kwargs
                        )
                        self.relu = nn.ReLU(inplace=True)
                        self.add_z = add_z

                    def forward(self, x):
                        z = self.conv(x)
                        out = self.conv(x)
                        if self.add_z:
                            out += z
                        out = self.relu(out)
                        return out

                mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda()

                inps = [5, 3, 4, 4]
                if conv == nn.Conv3d:
                    inps.append(inps[-1])
                inp = torch.rand(inps).cuda()

                if tracing:
                    scripted_mod = torch.jit.trace(mod_eager, (inp,))
                else:
                    scripted_mod = torch.jit.script(mod_eager)

                frozen_mod = torch.jit.optimize_for_inference(scripted_mod)
                if TEST_WITH_ROCM:
                    if add_z:
                        FileCheck().check("aten::miopen_convolution_add_relu").run(
                            frozen_mod.graph
                        )
                    else:
                        FileCheck().check("aten::miopen_convolution_relu").run(
                            frozen_mod.graph
                        )
                else:
                    if add_z:
                        FileCheck().check("aten::cudnn_convolution_add_relu").run(
                            frozen_mod.graph
                        )
                    else:
                        FileCheck().check("aten::cudnn_convolution_relu").run(
                            frozen_mod.graph
                        )

                self.assertEqual(mod_eager(inp), frozen_mod(inp))

    @unittest.skipIf(not (TEST_CUDNN or TEST_WITH_ROCM), "requires CUDNN")
    def test_freeze_conv_relu_fusion_not_forward(self):
        with set_default_dtype(torch.float):

            class Net(nn.Module):
                def __init__(self, in_channels, out_channels, **kwargs):
                    super().__init__()
                    self.conv = nn.Conv2d(
                        in_channels, out_channels, bias=False, **kwargs
                    )
                    self.relu = nn.ReLU(inplace=True)

                def forward(self, x):
                    z = self.conv(x)
                    out = self.conv(x)
                    out = self.relu(out)
                    return out

                @torch.jit.export
                def make_prediction(self, x):
                    return self.forward(x)

            mod_eager = Net(3, 6, kernel_size=3, stride=2).eval().cuda()

            inps = [5, 3, 4, 4]
            inp = torch.rand(inps).cuda()

            scripted_mod = torch.jit.script(mod_eager)

            frozen_mod = torch.jit.freeze(
                scripted_mod, preserved_attrs=["make_prediction"]
            )
            optimized_mod = torch.jit.optimize_for_inference(
                frozen_mod, other_methods=["make_prediction"]
            )
            if TEST_WITH_ROCM:
                FileCheck().check("aten::miopen_convolution_relu").run(
                    optimized_mod.make_prediction.graph
                )
            else:
                FileCheck().check("aten::cudnn_convolution_relu").run(
                    optimized_mod.make_prediction.graph
                )

            self.assertEqual(
                mod_eager.make_prediction(inp), optimized_mod.make_prediction(inp)
            )

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_numel_less_than_size_with_padding(self):
        with set_default_dtype(torch.float):

            class MyModule(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.conv1 = nn.Conv2d(
                        1,
                        2,
                        kernel_size=(2, 4),
                        stride=2,
                        padding=2,
                        dilation=(2, 1),
                    )

                def forward(self, i0):
                    x = self.conv1(i0)
                    o0 = torch.max(x, i0)
                    o1 = torch.clip(x, -1.5, 1.5)
                    return o0, o1

            i0 = torch.zeros((1, 1, 1, 2), dtype=torch.float32)
            mod = MyModule()
            out = mod(i0)

            exported = torch.jit.trace(mod, [i0])
            exported = torch.jit.optimize_for_inference(exported)

            eout = exported(i0)
            self.assertTrue(all(torch.allclose(x, y) for x, y in zip(out, eout)))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_incompatible_perf_formats(self):
        with set_default_dtype(torch.float):

            class Mod(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.conv = torch.nn.Conv2d(3, 64, 3, 2)
                    self.max_pool = torch.nn.MaxPool2d(111, 111)

                def forward(self, x):
                    a = self.conv(x)
                    b = self.max_pool(a)
                    return a + b

            model = Mod()
            model.eval()
            mod = torch.jit.freeze(torch.jit.script(model))
            N, C, H, W = 10, 3, 224, 224
            inp = torch.randn(N, C, H, W)
            self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
            self.assertEqual(model(inp), mod(inp))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_pool2d_batchnorm(self):
        with set_default_dtype(torch.float):
            pooling_layers = [
                torch.nn.AdaptiveAvgPool2d(4),
                # torch.nn.AdaptiveMaxPool2d(4), # returns tuples
                torch.nn.MaxPool2d(4),
                torch.nn.AvgPool2d(4),
                torch.nn.BatchNorm2d(64).eval(),
            ]

            for pl in pooling_layers:
                sub_model = torch.nn.Sequential(
                    torch.nn.Conv2d(3, 64, 2, 2),
                    torch.nn.ReLU(),
                    pl,
                    torch.nn.Hardswish(),
                )
                sub_model.eval()
                mod = torch.jit.freeze(torch.jit.script(sub_model))
                N, C, H, W = 10, 3, 224, 224
                inp = torch.randn(N, C, H, W)
                # these two passes are needed to remove
                # a size check in BatchNorm2d
                removeExceptions(mod.graph)
                self.run_pass("dce", mod.graph)
                self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
                FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
                self.assertEqual(sub_model(inp), mod(inp))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_pool3d_batchnorm(self):
        with set_default_dtype(torch.float):
            pooling_layers = [
                torch.nn.MaxPool3d(4),
                # torch.nn.AdaptiveAvgPool3d(4), # no ideep bindings
                # torch.nn.AdaptiveMaxPool3d(4), # returns tuples
                torch.nn.AvgPool3d(4),
                torch.nn.BatchNorm3d(64).eval(),
            ]

            for pl in pooling_layers:
                sub_model = torch.nn.Sequential(
                    torch.nn.Conv3d(3, 64, 2, 2),
                    torch.nn.ReLU(),
                    pl,
                    torch.nn.Hardswish(),
                )
                sub_model.eval()
                mod = torch.jit.freeze(torch.jit.script(sub_model))
                N, C, H, W, D = 10, 3, 64, 64, 64
                inp = torch.randn(N, C, D, H, W)
                # these two passes are needed to remove
                # a size check in BatchNorm3d
                removeExceptions(mod.graph)
                self.run_pass("dce", mod.graph)
                self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
                FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
                self.assertEqual(sub_model(inp), mod(inp))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    @skipIfNoTorchVision
    def test_conv_hardswish(self):
        with set_default_dtype(torch.float):

            class Clamp(torch.nn.Module):
                def __init__(self, min_val, max_val, **kwargs):
                    super().__init__()
                    self.min_val = min_val
                    self.max_val = max_val

                def forward(self, x):
                    return torch.clamp(x, self.min_val, self.max_val)

            N, C, H, W = 10, 3, 224, 224
            activations = [
                torch.nn.Hardswish(),
                torch.nn.Hardsigmoid(),
                torch.nn.ReLU6(),
                torch.nn.Tanh(),
                torch.nn.Hardtanh(0.0, 6.0),
                torch.nn.Hardtanh(1.0, 100.0),
                torch.nn.Hardtanh(-100.0, -1.0),
                torch.nn.GELU(),
                Clamp(-100.0, -1.0),
                Clamp(1.0, 100.0),
                Clamp(0.0, 6.0),
                Clamp(-1.0, 0.0),
            ]
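            # each of these activations is expected to run in MKLDNN layout
            # after conversion; the single to_dense checked below confirms no
            # intermediate dense round-trips were inserted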

            model = torchvision.models.resnet18()
            for activation in activations:
                sub_model = torch.nn.Sequential(model.conv1, activation)
                sub_model.eval()
                mod = torch.jit.freeze(torch.jit.script(sub_model))
                inp = torch.randn(N, C, H, W)
                self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
                FileCheck().check_count("aten::to_dense", 1, exactly=True).run(
                    mod.graph
                )
                self.assertEqual(sub_model(inp), mod(inp))

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_hardswish_hardsigmoid(self):
        with set_default_dtype(torch.float):
            op_map = {
                "prim::MKLDNNHardSwish": F.hardswish,
                "prim::MKLDNNHardSigmoid": F.hardsigmoid,
            }

            input_sizes = ([0], [1], [3], [1, 3, 8, 8])
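            # sizes cover an empty tensor, small 1-d tensors, and a 4-d
            # NCHW-shaped input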
            for mkldnn_opname, aten_op in op_map.items():
                for size in input_sizes:
                    for inplace in (True, False):
                        inplace_str = "_" if inplace else ""
                        inplace_tgt = "%34" if inplace else "%35"
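                        # build the IR by hand so both the out-of-place and the
                        # inplace (`_`) variants of the MKLDNN op are exercised
                        # directly; the inplace variant mutates %34, so that is
                        # the value we return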
                        graph_str = f"""graph(%input.1 : Tensor):
                            %33 : None = prim::Constant()
                            %34 : Tensor = aten::to_mkldnn(%input.1, %33)
                            %35 : Tensor = {mkldnn_opname}{inplace_str}(%34)
                            return ({inplace_tgt})
                        """
                        g = torch._C.parse_ir(graph_str)
                        m = self.createFunctionFromGraph(g)
                        x = torch.rand(size)
                        # `inplace=False` is intentional; otherwise we would
                        # modify the input, and we aren't testing the aten
                        # impls here anyway
                        self.assertEqual(aten_op(x, inplace=False), m(x).to_dense())

    @unittest.skipIf(
        not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled"
    )
    def test_scalar_mul(self):
        with set_default_dtype(torch.float):

            class Mod(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.mod = nn.Conv2d(8, 8, 1, padding=1)

                def forward(self, x):
                    a1 = self.mod(x) * 4
                    return a1 * 4 + a1 * 5.0

            mod = Mod().eval()
            scripted = torch.jit.freeze(torch.jit.script(mod))
            optimized = torch.jit.optimize_for_inference(scripted)
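            # optimize_for_inference should fold the scalar multiplications
            # into MKLDNN ScalarMul ops, which the FileCheck below matches by name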
            inp = torch.rand([1, 8, 8, 8])
            # a1 can't be inplaced at its first use, but can be at its second
            FileCheck().check("ScalarMul(").check("ScalarMul_").run(optimized.graph)
            self.assertEqual(optimized(inp), mod(inp))

    def test_remove_detach(self):
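        # detach is a no-op in a frozen, inference-only graph, so freezing
        # should remove it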
        class Mod(nn.Module):
            def forward(self, x):
                y = x.detach()
                return y * y

        mod = Mod().eval()
        frozen_mod = torch.jit.freeze(torch.jit.script(mod))
        inp = torch.randn((2, 2))
        FileCheck().check_not("aten::detach").run(frozen_mod.graph)
        self.assertEqual(frozen_mod(inp), mod(inp))

    def test_remove_detach_not_applied(self):
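        # here the graph observes object identity via `is`, and detach
        # produces a distinct tensor, so the pass must leave it in place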
        class Mod(nn.Module):
            def forward(self, x):
                y = x.detach()
                return x is y

        mod = Mod().eval()
        frozen_mod = torch.jit.freeze(torch.jit.script(mod))
        inp = torch.randn((2, 2))
        FileCheck().check("aten::detach").run(frozen_mod.graph)
        self.assertEqual(frozen_mod(inp), mod(inp))


@skipIfTorchDynamo("somehow causes a hang during python shutdown")
@unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKL-DNN build is disabled")
class TestMKLDNNReinplacing(JitTestCase):
    def setUp(self):
        super().setUp()
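        # the MKLDNN conversion pass used by these tests expects float32
        # tensors, so pin the default dtype for the duration of each test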
        self.default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(torch.float)

    def tearDown(self):
        super().tearDown()
        torch.set_default_dtype(self.default_dtype)

    def getConv(self):
        return nn.Conv2d(3, 32, kernel_size=3, stride=2).eval()

    def getInput(self):
        return torch.rand([4, 3, 4, 4])

    def freezeAndConvert(self, mod):
        mod = torch.jit.freeze(torch.jit.script(mod.eval()))
        self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
        return mod

    def checkResults(self, mod1, mod2):
        inp = self.getInput()
        self.assertEqual(mod1(inp), mod2(inp))

    def test_successful(self):
        # simple conv -> hardswish -> relu chain

        mod_eager = nn.Sequential(self.getConv(), nn.Hardswish(), nn.ReLU())
        mod = self.freezeAndConvert(mod_eager)
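        # both the hardswish and the relu consume values that die after use,
        # so the reinplacer should rewrite them to their inplace variants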
        FileCheck().check("mkldnn_convolution").check_next(
            "prim::MKLDNNHardSwish_"
        ).check_next("aten::relu_").run(mod.graph)
        self.checkResults(mod_eager, mod)

    def test_merge_liveness(self):
        class Mod(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                # this mul can be inplaced since x is dead after this use
                temporary = x * self.tensor
                # temporary's lifespan extends to the return node,
                # so the add cannot be inplaced
                return temporary + temporary, temporary

        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
        mod = self.freezeAndConvert(mod_eager)
        FileCheck().check("aten::mul_").check_not("aten::add_").run(mod.graph)
        self.checkResults(mod_eager, mod)

    def test_always_alive_values(self):
        class Mod(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                # x can't be inplaced because it's a return value;
                # check that the inplacing pass doesn't try to inplace
                # self.tensor because it's always alive
                return x * self.tensor, x

        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
        mod = self.freezeAndConvert(mod_eager)
        FileCheck().check_not("aten::mul_").run(mod.graph)
        self.checkResults(mod_eager, mod)

        conv = self.getConv()

        class Mod(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tensor = torch.rand([4, 32, 1, 1])
                self.conv = conv

            def forward(self, x):
                # the shapes don't add up here; this is just testing a particular pattern
                conv_output = self.conv(x)
                return conv_output, self.conv(torch.add(x, x))

        mod = self.freezeAndConvert(Mod())
        # x is an input to the graph, and so it should not be inplaced
        # in the torch.add(x, x) call
        FileCheck().check_not("aten::add_").run(mod.graph)

    def test_switch_inputs_to_inplace(self):
        class Mod(nn.Module):
            def __init__(self, tensor):
                super().__init__()
                self.tensor = tensor

            def forward(self, x):
                # self.tensor cannot be inplaced, but x can, and because
                # add is commutative we can reverse the inputs to add_
                return self.tensor + x

        mod_eager = nn.Sequential(self.getConv(), Mod(torch.rand([4, 32, 1, 1])))
        mod = self.freezeAndConvert(mod_eager)
        FileCheck().check("aten::add_").run(mod.graph)
        self.checkResults(mod_eager, mod)
