# Owner(s): ["oncall: jit"]

import copy
import io
import os
import sys
import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable
from torch.testing import FileCheck


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
import warnings

# Standard library
from collections import namedtuple
from itertools import chain
from typing import Dict, List, Optional, Tuple

from torch import Tensor
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_utils import (
    enable_profiling_mode_for_profiling_tests,
    IS_SANDCASTLE,
    skipIfCompiledWithoutNumpy,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TemporaryFileName,
)
from torch.testing._internal.jit_utils import (
    _tmp_donotuse_dont_inline_everything,
    _trace,
    enable_cpu_fuser,
    JitTestCase,
    make_global,
    RUN_CUDA,
    RUN_CUDA_MULTI_GPU,
)


if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )


@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestTracer(JitTestCase):
    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_large_nbr_kernel_args(self):
        class Recurrence(nn.Module):
            def __init__(self, seq_len):
                super().__init__()
                self.seq_len = seq_len

            def forward(self, input):
                input = input.transpose(0, 1)

                # Main loop
                output = []
                for i in range(self.seq_len):
                    b = input[i] * 2
                    output.append(b)

                output = torch.cat(output, 0).view(input.size(0), *output[0].size())
                output = output.transpose(0, 1)
                return output

        input_size = 8
        batch_size = 2
        seq_len = 130

        rec = Recurrence(seq_len)
        input = torch.rand(batch_size, seq_len, input_size)

        torch.cuda.set_device(0)
        rec = rec.cuda()
        input = input.cuda()

        traced_rec = torch.jit.trace(rec, (input))

    def test_trace_legacy_ctor(self):
        class MyModule(nn.Module):
            def forward(self, x):
                return (x + 1, torch.FloatTensor([0]))

        traced_rec = torch.jit.trace(MyModule(), torch.randn(2, 2))

    def test_simple(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0.7], requires_grad=True)

        def f(x, y):
            return torch.sigmoid(torch.tanh(x * (x + y)))

        self.checkTrace(f, (x, y))

    def test_trace_checking_with_global_name(self):
        class MyClass(torch.nn.Module):
            def forward(self, xs: List[Tensor]):
                y = torch.cat(xs, dim=0)
                return y

        model = MyClass()
        # Simulate these inputs being in the globals, like they would be if,
        # e.g. they were defined outermost scope of a script
        global input1, input2
        input1 = torch.ones(2, 2)
        input2 = torch.ones(2, 2)
        m2 = torch.jit.trace(model, ((input1, input2),))

    def test_trace_aliased_parameter(self):
        class M(nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = nn.Parameter(x)

            def forward(self, y):
                return self.x + y

        m = M(torch.rand(3, 4))
        r = torch.jit.trace(m, m.x)
        t2 = torch.rand(3, 4)
        self.assertEqual(r(t2), m.x + t2)

    def test_trace_nested_fn(self):
        class TracedInlineDecision(torch.nn.Module):
            def forward(self, x, flag):
                @torch.jit.script
                def make_decision(flag, x):
                    if flag:
                        return x
                    else:
                        return torch.zeros_like(x)

                x = torch.neg(x)
                return make_decision(flag, x)

        decision = TracedInlineDecision()
        torch.jit.trace(
            decision,
            (torch.rand(3, 4), torch.tensor([True], dtype=torch.bool)),
            check_trace=True,
        )

    def test_trace_single_tuple(self):
        x = torch.tensor(2.0)

        def f2(x):
            return (x,)

        jit_f2 = torch.jit.trace(f2, x)
        assert f2(x) == jit_f2(x)  # fails

    def test_trace_out_operator_with_two_output(self):
        example_input = torch.rand(2, 8)
        out_1, out_2 = torch.cummax(example_input, 1)

        def run_cummax(example_input, out_1, out_2):
            output_1, output_2 = torch.cummax(example_input, 1, out=(out_1, out_2))
            return output_1, output_2

        trace_model = torch.jit.trace(run_cummax, (example_input, out_1, out_2))

    def test_trace_namedtuple(self):
        Point = namedtuple("point", ["x", "y"])

        def f(p):
            if type(p) is tuple:
                p = Point(*p)
            return p.x + p.y

        p = Point(torch.randn(1), torch.randn(1))
        traced = torch.jit.trace(f, (p,))
        self.assertEqual(f(p), traced(p))

    def test_trace_topk(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                return x.topk(y, dim=1)[1]

        mod = M()
        inputs = (torch.randint(0, 10, (20, 20)), torch.tensor(17))
        traced_func = torch.jit.trace(mod, inputs)

        test_inputs = (torch.randint(0, 9, (9, 9)), torch.tensor(8))
        eager_out = mod(*test_inputs)
        traced_out = traced_func(*test_inputs)
        self.assertNotWarn(
            lambda: traced_func(*test_inputs),
            "Shouldn't throw slicing related warn here",
        )
        self.assertEqual(eager_out, traced_out)

        test_inputs = (torch.randint(0, 50, (50, 50)), torch.tensor(12))
        eager_out = mod(*test_inputs)
        traced_out = traced_func(*test_inputs)
        self.assertNotWarn(
            lambda: traced_func(*test_inputs),
            "Shouldn't throw slicing related warn here",
        )
        self.assertEqual(eager_out, traced_out)

    def test_typeas_trace_check(self):
        a = torch.tensor([0.4], requires_grad=True)
        b = torch.tensor([0.7], requires_grad=True)

        def f(x, y):
            return x.type_as(y)

        trace = torch.jit.trace(f, (a, b))

    def test_trace_index(self):
        x = torch.tensor([0.4], requires_grad=True)
        y = torch.tensor([0], dtype=torch.int64)

        def fn(x, y):
            return x[y]

        fn_traced = torch.jit.trace(
            fn,
            (
                x,
                y,
            ),
        )

        self.assertEqual(fn(x, y), fn_traced(x, y))

    # Backwards tracing was broken for indexing by a constant,
    # because it's internally implemented using as_strided,
    # and we attempted to trace its derivative (which is not
    # currently supported.)  It currently works because
    # slice() is now not marked as traceable.
    def test_trace_index_constant(self):
        x = torch.tensor([0.4], requires_grad=True)

        def fn(x):
            return x[0]

        def run(f):
            y = f(x)
            grad = torch.autograd.grad(y, x)[0].clone()
            return y, grad

        traced_fn = torch.jit.trace(fn, torch.ones(1))
        self.assertEqual(run(fn), run(traced_fn))

    def test_index_put(self):
        ten = torch.zeros(3, 3)
        mask = torch.tensor(
            [[True, True, True], [True, False, False], [True, True, False]]
        )

        def test_fn(ten, mask):
            ten[mask] = torch.ones(6)
            return ten

        traced_test_fn = torch.jit.trace(test_fn, (ten, mask))

        ten = torch.rand(3, 3)
        self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))

    def test_canonicalize_tensor_iterator(self):
        x = torch.randn(4, 4)

        def f(x):
            x = x + 2
            x = x - 4
            x = x * 6
            x = x / 8
            return x

        traced = torch.jit.trace(f, (x,))
        f(x)
        graph = traced.graph_for(x)
        # There should be 4 int constants for the right sides of operators, plus one
        # for the alpha argument for add and sub
        self.assertTrue(str(traced.graph_for(x)).count(": int = prim::Constant") == 5)

    @suppress_warnings
    def test_constant(self):
        x = torch.randn(2, 2, requires_grad=True)

        def f(x):
            return x.matmul(torch.diag(torch.tensor([2.0, 2.0])))

        self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))

    def test_wrapped_number(self):
        # Scalar's get converted to 'wrapped' tensors of default tensor type.
        # Wrapped tensors behave differently in certain promotion operations:
        # float_tensor * double -> float but wrapped_float * double -> double.
        # This can cause issues in check-trace if not handled correctly in
        # `aten::isclose()`.

        def foobar():
            x = -10000.0
            result = x * torch.ones(1, dtype=torch.float)
            return result

        scripted = torch.jit.trace(foobar, (), check_trace=True)

    def test_inplace_transplant(self):
        x = torch.tensor([0.0], requires_grad=True)

        def fn(x):
            y = x.clone()
            y.add_(2)
            y.add_(3)
            return y

        g, _ = torch.jit._get_trace_graph(fn, (x,))
        self.run_pass("dce", g)
        FileCheck().check_count("aten::clone", 1, exactly=True).check_count(
            "aten::add_", 2, exactly=True
        ).check_next("return").run(str(g))
        self.assertExportImport(g, (x,))

    def test_inplace_flags(self):
        class InplaceFn(Function):
            @staticmethod
            def forward(ctx, x):
                ctx.mark_dirty(x)
                return x.add_(1)

            @staticmethod
            def backward(ctx, go):
                return go

        class RegularFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x.add(1)

            @staticmethod
            def backward(ctx, go):
                return go

        x = torch.tensor([0.0], requires_grad=True)

        def fn(x):
            y = RegularFn.apply(x)
            y = InplaceFn.apply(y)
            y = InplaceFn.apply(y)
            y = RegularFn.apply(y)
            return y

        trace_graph, _ = torch.jit._get_trace_graph(fn, (x,), _force_outplace=True)
        self.run_pass("dce", trace_graph)
        ops = list(trace_graph.nodes())
        for op in ops:
            self.assertTrue(op.hasAttribute("inplace"))
        inplace_flags = [False, True, True, False]
        for op, is_inplace in zip(ops, inplace_flags):
            self.assertEqual(op.i("inplace"), is_inplace)

    def test_inplace_check(self):
        class MyInplaceFn(Function):
            @staticmethod
            def forward(self, x):
                x.add_(1)
                self.mark_dirty(x)
                return x

            @staticmethod
            def backward(self, grad):
                return grad

        def fn(x):
            return MyInplaceFn.apply(x)

        x = torch.randn(5, 5)
        ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False)
        with self.assertRaisesRegex(RuntimeError, "inplace MyInplaceFn"):
            ge(x)

    def test_force_outplace_check_fill(self):
        def f(x):
            return torch.empty(x.shape).fill_(7)

        x = torch.randn(10, 15)
        ft = torch.jit.trace(f, x, _force_outplace=True)
        self.assertEqual(f(x), ft(x))

    def test_force_outplace_check_zero(self):
        def f(x):
            return torch.empty(x.shape).zero_()

        x = torch.randn(10, 15)
        ft = torch.jit.trace(f, x, _force_outplace=True)
        self.assertEqual(f(x), ft(x))

    def do_trace_size(self, requires_grad):
        def fn(x):
            return x.view(x.shape[1] * 2, x.size(0), 2)

        x = torch.randn(5, 2, 4, requires_grad=requires_grad)
        y = torch.randn(4, 8, 4, requires_grad=requires_grad)

        # Check that it behaves as expected
        traced_fn = torch.jit.trace(fn, x)
        self.assertEqual(traced_fn(y), fn(y))
        self.assertEqual(traced_fn(x), fn(x))

    def test_trace_size(self):
        self.do_trace_size(False)

    # test the different graph_executor path that happens when
    # gradients are required and sizes are involved
    def test_trace_size_with_grad(self):
        self.do_trace_size(True)

    def test_trace_numel(self):
        def fn(x):
            return x.numel()

        x = torch.randn(2, 3, 4)
        y = torch.randn(4, 5, 6)

        traced_fn = torch.jit.trace(fn, x)
        self.assertEqual(traced_fn(y), fn(y))
        self.assertEqual(traced_fn(x), fn(x))

    def do_trace_arange(self, requires_grad):
        def arange(x):
            return torch.arange(x.shape[0])

        def arange_scalar(x):
            return torch.arange(12)

        def arange_start_end(x):
            return torch.arange(start=x.shape[0], end=x.shape[0] + 5)

        x = torch.randn(5, 3, 2, requires_grad=requires_grad)
        y = torch.randn(8, 2, 4, requires_grad=requires_grad)

        # Check that it behaves as expected
        traced_arange = torch.jit.trace(arange, x)
        self.assertEqual(traced_arange(y), arange(y))
        self.assertEqual(traced_arange(x), arange(x))

        traced_arange_scalar = torch.jit.trace(arange_scalar, x)
        self.assertEqual(traced_arange_scalar(y), arange_scalar(y))
        self.assertEqual(traced_arange_scalar(x), arange_scalar(x))

        traced_arange_start_end = torch.jit.trace(arange_start_end, x)
        self.assertEqual(traced_arange_start_end(y), arange_start_end(y))
        self.assertEqual(traced_arange_start_end(x), arange_start_end(x))

    def test_trace_arange(self):
        self.do_trace_arange(False)

    # test the different graph_executor path that happens when
    # gradients are required and sizes are involved
    def test_trace_arange_with_grad(self):
        self.do_trace_arange(True)

    # Test that a trace of torch.full(x.shape) doesn't store the shape as a constant
    def test_trace_full_dynamic_shape(self):
        def full_with_shape_like(x):
            return torch.full(x.shape, 2.0)

        x = torch.randn(3, 4)
        ge = torch.jit.trace(full_with_shape_like, example_inputs=x)
        y = torch.randn(2, 7)
        self.assertEqual(ge(y).shape, y.shape)
        self.assertEqual(ge(x).shape, x.shape)

    # Test that the trace of setitem doesn't store shapes as constants
    # Fix https://github.com/pytorch/pytorch/issues/43548
    def test_trace_slice_setitem_dynamic_shape(self):
        def slice_setitem(x, y):
            x[:, 2] = y + 1
            return x

        x = torch.randn(3, 4)
        traced = torch.jit.trace(slice_setitem, (x, x[:, 0]))
        x = torch.randn(10, 5)
        self.assertEqual(traced(x.clone(), x[:, 0]), slice_setitem(x.clone(), x[:, 0]))

    # Suppression: we are intentionally slicing a tensor, we don't care that it
    # will be constantified
    @suppress_warnings
    def do_trace_slice(self, requires_grad):
        def slice(x):
            results = []
            for i in range(4):
                results.append(x[: x.size(0) - i, i : x.size(2), i:3])
            return tuple(results)

        def slice_select(x):
            results = []
            for i in range(4):
                results.append(x[:, i:, x.size(2) - 5])
            return tuple(results)

        x = torch.randn(5, 6, 7, requires_grad=requires_grad)
        y = torch.randn(7, 8, 9, requires_grad=requires_grad)

        # Check that it behaves as expected
        traced_slice = torch.jit.trace(slice, x)
        self.assertEqual(traced_slice(y), slice(y))
        self.assertEqual(traced_slice(x), slice(x))

        traced_slice_select = torch.jit.trace(slice_select, x)
        self.assertEqual(traced_slice_select(y), slice_select(y))
        self.assertEqual(traced_slice_select(x), slice_select(x))

    def test_trace_slice(self):
        self.do_trace_slice(False)

    # test the different graph_executor path that happens when
    # gradients are required and sizes are involved
    def test_trace_slice_with_grad(self):
        self.do_trace_slice(True)

    def test_trace_casts(self):
        casts = [
            lambda x: x.byte(),
            lambda x: x.float(),
            lambda x: x.cpu(),
            lambda x: x.to(device="cpu"),
            lambda x: x.to(dtype=torch.int64),
            lambda x: x.to(device="cpu", dtype=torch.float),
            lambda x: x.to(x),
        ]

        def assertContainsCast(trace):
            self.assertEqual(
                sum(n.kind() == "aten::to" for n in trace.graph.nodes()), 1
            )

        for cast in casts:
            trace = torch.jit.trace(cast, torch.randn(2, 2))
            assertContainsCast(trace)
            x = torch.randn(2, 2)
            self.assertEqual(trace(x), cast(x))

        def to_tensor(x, y):
            return x.to(y)

        to_tensor_trace = torch.jit.trace(
            to_tensor, (torch.randn(2, 2), torch.randn(1, 8))
        )
        assertContainsCast(to_tensor_trace)
        x, y = torch.randn(2, 2), torch.randn(1, 10)
        self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))

    @skipIfCompiledWithoutNumpy
    @skipIfCrossRef
    def test_trace_warn(self):
        def fn(x):
            int(x)  # Warning 1.
            y = x * 1
            if y:  # Warning 2.
                pass
            q = [x, x * 4]
            z = q[y]
            float(z)  # Warning 3.
            z.tolist()  # Warning 4.
            z.numpy()  # Warning 5.
            for _ in torch.ones(4, 4):  # Warning 6.
                pass
            return z + 4

        with warnings.catch_warnings(record=True) as warns:
            traced_fn = torch.jit.trace(fn, torch.tensor([1]))
        for warn in warns:
            self.assertIs(warn.category, torch.jit.TracerWarning)
        warns = [str(w.message) for w in warns]
        self.assertIn("a Python integer", warns[0])
        self.assertIn("a Python boolean", warns[1])
        self.assertIn("a Python float", warns[2])
        self.assertIn("a Python list", warns[3])
        self.assertIn("a NumPy array", warns[4])
        self.assertIn("Iterating over", warns[5])

    def test_trace_tuple(self):
        def fn(x, y):
            return x, (x * y[1], x * y[0])

        x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
        traced_fn = torch.jit.trace(fn, (x, y))
        self.assertEqual(traced_fn(x, y), fn(x, y))
        # should be a tuple nested within another tuple
        FileCheck().check_count("prim::TupleConstruct", 2, exactly=True).check_next(
            "return"
        ).run(str(traced_fn.graph))
        self.assertExportImport(traced_fn.graph, (x, y))

    def test_trace_random(self):
        def f(mean, std):
            return torch.normal(mean, std)

        traced = torch.jit.trace(
            f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False
        )
        mean, std = torch.zeros(5, 5), torch.ones(5, 5)
        with torch.random.fork_rng(devices=[]):
            output = f(mean, std)
        traced_output = traced(mean, std)
        self.assertEqual(output, traced_output)

    def test_trace_tensor_factory(self):
        def run(**kwargs):
            inputs_require_grads = kwargs.pop("inputs_require_grads", True)

            def fn(x):
                return x + torch.ones(2, 3, **kwargs)

            input_kwargs = kwargs.copy()
            if "out" in input_kwargs:
                del input_kwargs["out"]
            input = torch.ones(2, 3, **input_kwargs)
            self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
            # check we recorded 'ones' and did not just record a constant
            tfn = torch.jit.trace(fn, input)
            self.assertTrue("ones" in str(tfn.graph))

        run()
        run(dtype=torch.int, inputs_require_grads=False)
        run(out=torch.tensor([]))
        if RUN_CUDA:
            run(device="cuda:0")
        if RUN_CUDA_MULTI_GPU:
            run(device="cuda:1")

    def test_trace_indexed_assignment(self):
        def stuff(x, y):
            x = x.clone()
            x[0] = y
            return x

        example = torch.rand(3, 4)
        self.checkTrace(stuff, (example, example[0] + 1))

    # TODO: implement
    @unittest.expectedFailure
    def test_output_unflatten(self):
        """Check that outputs of traced functions retain the original structure and nesting"""

        def fn(x):
            return (
                x * 2,
                (
                    x**2,
                    x + 4,
                    (x + 2,),
                ),
                x * 4,
            )

        self.checkTrace(fn, (torch.randn(2, 2),))

    def test_input_flatten(self):
        """Check that inputs to traced functions are flattened"""

        def fn(x, t):
            y, z = t
            return x * y * z

        inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
        self.checkTrace(fn, inputs)

    def test_input_dict_empty(self):
        def test(d):
            pass

        with self.assertRaises(RuntimeError):
            self.checkTrace(test, {})

    def test_input_dict_remembers_keys(self):
        """Check that the trace remembers which keys were in a dict input"""

        class TestModule(torch.nn.Module):
            def forward(self, dict_input):
                return dict_input["x"]

        input_1 = {"x": torch.tensor(1)}
        m = TestModule()
        m_traced = torch.jit.trace(m, (input_1,))
        self.assertEqual(m_traced(input_1), torch.tensor(1))

        # should work to change the values and not the keys
        input_same_key_different_value = {"x": torch.tensor(2)}
        self.assertEqual(m_traced(input_same_key_different_value), torch.tensor(2))

        # error to use something that doesn't have `x`
        input_different_key = {"y": torch.tensor(3)}
        with self.assertRaises(RuntimeError):
            m_traced(input_different_key)

        # it's okay to have additional elements in the dictionary, so long as 'x' is there
        input_additional_key = {"x": torch.tensor(4), "y": torch.tensor(3)}
        self.assertEqual(m_traced(input_additional_key), torch.tensor(4))

    def test_input_dict_insertion_order(self):
        """Check that dictionary access doesn't care about insertion order"""

        class TestModule(torch.nn.Module):
            def forward(self, dict_input):
                return dict_input["x"], dict_input["y"]

        input_x_then_y = {}
        input_x_then_y["x"] = torch.tensor(1)
        input_x_then_y["y"] = torch.tensor(2)

        m = TestModule()
        m_traced = torch.jit.trace(m, (input_x_then_y,))

        self.assertEqual(m_traced(input_x_then_y), (torch.tensor(1), torch.tensor(2)))

        input_y_then_x = {}
        input_y_then_x["y"] = torch.tensor(4)
        input_y_then_x["x"] = torch.tensor(3)

        self.assertEqual(m_traced(input_y_then_x), (torch.tensor(3), torch.tensor(4)))

    def test_input_dict_recursive(self):
        class TestModule(torch.nn.Module):
            def forward(self, dict_input):
                return dict_input["x"][1]

        input_1 = {"x": {1: torch.tensor(1)}}
        m = TestModule()
        m_traced = torch.jit.trace(m, (input_1,))

        input_2 = {"x": {1: torch.tensor(2)}}
        self.assertEqual(m_traced(input_2), torch.tensor(2))

    def test_input_dict_checkTrace_mut(self):
        def test(d):
            d["x"].tanh_()
            return d["x"]

        inputs = {"x": torch.rand(3, 4), "y": torch.rand(3, 4)}
        self.checkTrace(test, (inputs,), inputs_require_grads=False)

    def test_input_dict_unify(self):
        def test(d):
            return d["int"], d["float"]

        inputs = {
            "int": torch.ones((2, 2), dtype=torch.int32),
            "float": torch.ones((2, 2), dtype=torch.float32),
        }
        self.checkTrace(test, (inputs,), inputs_require_grads=False)

    def test_input_tuple_of_dicts(self):
        def test(t):
            d = t[0]
            return d["x"]["y"]

        inputs = {"x": {"y": torch.rand(2, 3)}}
        self.checkTrace(test, ((inputs, inputs),), allow_unused=True)

    def test_input_dict_of_dicts(self):
        def test(d):
            return d["x"]["y"]

        nested_input = {"y": torch.rand(2, 3)}
        unified_nested = {"y": torch.rand(3, 2)}
        inputs = {"x": nested_input, "force_unify": unified_nested}
        self.checkTrace(test, (inputs,), allow_unused=True)

    def test_input_dict_of_lists(self):
        def test(d):
            return d["x"][0]

        inputs = {"x": [torch.rand(3, 2)]}
        self.checkTrace(test, (inputs,))

    def test_input_list_toplevel_flatten(self):
        def test(t1, t2):
            return torch.add(t1, t2)

        inputs = [torch.ones(2, 2), torch.rand(2, 2)]
        self.checkTrace(test, inputs)

    def test_input_list_toplevel_flatten_direct(self):
        class Test(torch.nn.Module):
            def forward(self, t1, t2):
                return torch.add(t1, t2)

        inputs = [torch.ones(2, 2), torch.rand(2, 2)]
        torch.jit.trace(Test(), inputs)

    def test_input_list_of_tuples(self):
        def test(l):
            return l[0][0]

        inputs = [(torch.ones(2, 2),)]
        self.checkTrace(test, (inputs,))

    def test_input_dict_empty_list(self):
        def test(d):
            pass

        inputs = {1: []}
        with self.assertRaisesRegex(RuntimeError, "List trace"):
            self.checkTrace(test, (inputs,))

    def test_input_list_mixed_type(self):
        def test(d):
            pass

        inputs = [torch.rand(2, 3), (torch.ones(2), torch.ones(2))]
        with self.assertRaisesRegex(RuntimeError, "consistent"):
            self.checkTrace(test, (inputs,))

    def test_conv(self):
        x = torch.ones(20, 16, 50, 40)
        g, outputs, inputs = torch.jit._get_trace_graph(
            nn.Conv2d(16, 13, 3, bias=False), x, return_inputs=True
        )
        m = self.createFunctionFromGraph(g)
        self.assertEqual(outputs, m(*inputs))

    def test_max_pool(self):
        x = torch.rand(20, 16, 10, 10)

        def max_pool2d(x):
            return F.max_pool2d(x, 2) + 2

        trace = torch.jit.trace(max_pool2d, (x))
        graph = trace.graph_for(x)
        FileCheck().check("aten::max_pool2d(").run(graph)
        self.assertEqual(max_pool2d(x), trace(x))

    def test_nested_inplace(self):
        x = torch.randn(2, 2)
        g, outputs, inputs = torch.jit._get_trace_graph(
            lambda x: F.threshold(x, 0, 0, inplace=True), (x,), return_inputs=True
        )
        m = self.createFunctionFromGraph(g)
        self.assertEqual(outputs, m(*inputs))
        FileCheck().check("threshold_").run(str(g))
        self.assertExportImport(g, (x,))

    def test_repeated_input(self):
        def fn(a, b):
            return a + b

        ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
        inputs = set(ge.graph.inputs())
        # three instead of 2 because the export/import in checkTrace adds a
        # `self` module argument
        self.assertTrue(len(inputs) == 3)

    def test_repeated_output(self):
        def fn(a, b):
            z = a + b
            return z, z

        ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
        tuple_output = list(ge.graph.outputs())[0]
        tuple_inputs = list(tuple_output.node().inputs())
        self.assertTrue(tuple_inputs[0] == tuple_inputs[1])

    def test_inplace_copy(self):
        x = torch.randn(4, 4, requires_grad=True)

        def f(x):
            out = torch.zeros(x.size())
            out.copy_(x)
            return out

        g, outputs, inputs = torch.jit._get_trace_graph(f, (x,), return_inputs=True)
        self.run_pass("dce", g)
        m = self.createFunctionFromGraph(g)
        self.assertEqual(outputs, m(*inputs))
        self.assertExportImport(g, (x,))

    def test_inplace_copy_force_outplace(self):
        x = torch.randn(4, 4, requires_grad=True)

        def f(x):
            out = torch.zeros(x.size())
            out.copy_(x)
            return out

        g, outputs, inputs = torch.jit._get_trace_graph(
            f, (x,), return_inputs=True, _force_outplace=True
        )
        self.run_pass("dce", g)
        m = self.createFunctionFromGraph(g)
        self.assertEqual(outputs, m(*inputs))
        self.assertExportImport(g, (x,))
        FileCheck().check("expand_as").run(str(g))

    def test_shared_param(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = self.a = nn.Parameter(torch.randn(2, 2))

            def forward(self, x):
                return x * self.a + self.b

        m = MyModule()
        g, _ = torch.jit._get_trace_graph(m, (torch.randn(2, 2),))
        self.run_pass("dce", g)
        self.assertEqual(len(list(g.inputs())), 2)
        FileCheck().check("mul").check("add").run(str(g))

    def run_ge_tests(self, optimize, use_cuda):
        with enable_profiling_mode_for_profiling_tests():
            with torch.jit.optimized_execution(optimize):

                def rand(*args):
                    t = torch.rand(*args).float()
                    if use_cuda:
                        t = t.cuda()
                    return t

                self.checkTrace(
                    lambda a, b: a * b + b, [rand(1), rand(1)], [rand(2, 3), rand(2, 3)]
                )
                # trivial identity
                self.checkTrace(lambda a, b: (b, a), [rand(1), rand(1)])

                def foo(a):
                    t = a * a
                    return t * t, 4 * t

                self.checkTrace(foo, [rand(1)])
                # unused input
                self.checkTrace(
                    lambda a, b: a * a, [rand(1), rand(1)], allow_unused=True
                )
                # test outputs that do not get used in grad
                self.checkTrace(foo, [rand(1)], drop=1)
                # test autograd fallback
                self.checkTrace(
                    lambda a, b: a * b / (a - 2 * b) + b, [rand(1), rand(1)]
                )

    def test_ge_unoptimized(self):
        self.run_ge_tests(False, False)

    @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
    @enable_cpu_fuser
    def test_ge_optimized(self):
        with enable_profiling_mode_for_profiling_tests():
            self.run_ge_tests(True, False)

    @unittest.skipIf(not RUN_CUDA, "requires CUDA")
    def test_ge_cuda(self):
        self.run_ge_tests(True, True)

    # more manual test of graph executor that can be used as a scratchpad
    def test_ge(self):
        def foo(a, b):
            return a * b / (a - b) + b

        V = Variable
        a, b = V(torch.rand(1)), V(torch.rand(1))
        ge = torch.jit.trace(foo, (a, b))
        a, b = V(torch.rand(1), requires_grad=True), V(
            torch.rand(1), requires_grad=True
        )
        (r,) = ge(a, b)
        da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)

        l2 = da * db + db * db
        g2result = torch.autograd.grad(l2, [da, db])

        r = foo(a, b)
        da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
        self.assertEqual(da, da2)
        self.assertEqual(db, db2)
        l3 = da2 * db2 + db2 * db2
        g2result2 = torch.autograd.grad(l3, [da2, db2])
        self.assertEqual(g2result, g2result2)

    def test_trace_annotation(self):
        @_trace(torch.rand(1))
        def foo(a):
            return a + a + a

        x = torch.randn(5, 5)
        self.assertEqual(foo(x), x + x + x)

    @unittest.skipIf(not RUN_CUDA, "calls .cuda()")
    # By default, on Ampere or later GPUs, nn.Linear computes float tensors at TF32 precision.
    # We want float tensors to be computed at full precision in order to use the default precision
    @with_tf32_off
    def test_traced_module_cuda(self):
        class Model(nn.Module):
            def __init__(self, num_features, num_layers):
                super().__init__()
                self.num_layers = num_layers
                layers = [
                    [nn.Linear(num_features, num_features), nn.Sigmoid()]
                    for _ in range(num_layers)
                ]
                self.submodule = nn.Sequential(*chain(*layers))

            def forward(self, x):
                for i in range(self.num_layers):
                    x = self.submodule[i](x) + x
                return x

        model = Model(5, 3)
        x = torch.randn(2, 5)
        traced_model = torch.jit.trace(model, x)

        # We're missing some attributes these modules had initially. Make sure we can
        # still get the __repr__()
        model.__repr__()

        # XXX: indexing sequentials is broken
        linear_submodule = next(iter(traced_model.submodule._modules.values()))

        # All attributes that aren't parameters should raise
        with self.assertRaises(AttributeError):
            linear_submodule.in_features
        linear_submodule.weight
        linear_submodule.weight = nn.Parameter(
            torch.randn(linear_submodule.weight.shape)
        )
        with self.assertRaises(RuntimeError):
            del linear_submodule.weight

        # Submodules can't be called
        with self.assertRaises(RuntimeError):
            linear_submodule(x)

        # Type casts
        linear_submodule.cuda()
        traced_model.float().cuda()
        cuda_out = traced_model(x.float().cuda())
        traced_model.cpu()
        cpu_out = traced_model(x.float())
        self.assertEqual(cpu_out, cuda_out)
        traced_model.to("cuda")
        cuda_out = traced_model(x.float().cuda())
        traced_model.to("cpu")
        cpu_out = traced_model(x.float())
        self.assertEqual(cpu_out, cuda_out)
        traced_model.to(torch.get_default_dtype())

        # state_dict + load_state_dict
        state = {k: v.clone() for k, v in traced_model.state_dict().items()}
        new_state = {k: v.clone().fill_(1) for k, v in state.items()}
        out = traced_model(x)
        traced_model.load_state_dict(new_state)
        out_ones = traced_model(x)
        traced_model.load_state_dict(state)
        out_state = traced_model(x)
        self.assertEqual(out, out_state)
        self.assertNotEqual(out, out_ones)

    @unittest.skipIf(not RUN_CUDA, "uses cuda")
    def test_type_same_device(self):
        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.dtype = torch.float16

            def forward(self, x=None):
                h = x.type(self.dtype)
                return h

        a = Model()
        b = torch.jit.trace(
            a, example_inputs=(torch.ones([1], device=torch.device("cuda")),)
        )
        FileCheck().check_not("device").run(b.code)

    def test_export_no_reorder(self):
        def func(a, b):
            return a * b / (a - 2 * b) + b

        recording_inputs = [
            torch.tensor(
                [0.55619788169860839844], dtype=torch.float32, requires_grad=True
            ),
            torch.tensor(
                [0.25947844982147216797], dtype=torch.float32, requires_grad=True
            ),
        ]

        ge1 = torch.jit.trace(func, recording_inputs)
        ge2 = self.getExportImportCopy(ge1)

        outputs_ge1 = ge1(*recording_inputs)
        outputs_ge2 = ge2(*recording_inputs)

        grad_ge1 = torch.autograd.grad(outputs_ge1, recording_inputs)
        grad_ge2 = torch.autograd.grad(outputs_ge2, recording_inputs)
        self.assertTrue(outputs_ge1 == outputs_ge2)
        self.assertTrue(grad_ge1 == grad_ge2)

    def test_python_function(self):
        class MyFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x + 1

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output

        @_trace(torch.zeros(2))
        def fn(x):
            return MyFn.apply(x + 2) + 3

        x = torch.tensor([1.0, 2.0, 3.0])
        y = torch.randn(2, 2, requires_grad=True)
        fn(x)
        fn(y)

    def test_python_function_tup(self):
        class MyFn(Function):
            @staticmethod
            def forward(ctx, x):
                return x + 1, x - 1

            @staticmethod
            def backward(ctx, grad_output):
                return grad_output, grad_output

        @_trace(torch.zeros(2))
        def fn(x):
            a, b = MyFn.apply(x + 2)
            return a + b + 3

        x = torch.tensor([1.0, 2.0, 3.0])
        y = torch.randn(2, 2, requires_grad=True)
        fn(x)
        fn(y)

    def test_trace_detach(self):
        def foo(x, w):
            return torch.matmul(x, w).detach()

        traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))

        FileCheck().check("matmul").check("detach").run(str(traced.graph))
        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
        traced_result = traced(x, w)
        self.assertEqual(foo(x, w), traced_result)
        self.assertFalse(traced_result.requires_grad)
        self.assertIsNone(traced_result.grad_fn)

    def test_trace_detach_redispatch(self):
        def foo(x, w):
            y = torch.matmul(x, w)
            assert y.requires_grad
            y = y.detach()
            # Make sure trace kernel redispatches to the right lower kernel.
            assert not y.requires_grad
            return y

        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
        # With `check_trace=True` it will run with `@torch.no_grad()` and break assert.
        torch.jit.trace(foo, (x, w), check_trace=False)

    def test_trace_detach_inplace(self):
        def foo(x, w):
            y = torch.matmul(x, w)
            y.detach_()
            return y

        traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))

        FileCheck().check("matmul").check("detach(").run(str(traced.graph))
        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
        traced_result = traced(x, w)
        self.assertEqual(foo(x, w), traced_result)
        self.assertFalse(traced_result.requires_grad)
        self.assertIsNone(traced_result.grad_fn)

    def test_trace_detach_inplace_redispatch(self):
        def foo(x, w):
            y = torch.matmul(x, w)
            assert y.requires_grad
            y.detach_()
            # Make sure trace kernel redispatches to the right lower kernel.
            assert not y.requires_grad
            return y

        x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
        # With `check_trace=True` it will run with `@torch.no_grad()` and break assert.
        torch.jit.trace(foo, (x, w), check_trace=False)

    def test_trace_slice_full_dim(self):
        def foo(x):
            return x[0:5, 0] + 1.0

        traced = torch.jit.trace(foo, (torch.rand(5, 4),))
        test_x = torch.rand(6, 3)
        self.assertEqual(foo(test_x), traced(test_x))

    def test_trace_dict_input(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Foo()

            def forward(self, a, b):
                return self.foo({"a": a, "b": b})["a"]

        class Foo(torch.nn.Module):
            def forward(self, x):
                return {"a": x["a"] * x["b"]}

        x = (torch.rand(3), torch.rand(3))
        model = Bar()
        self.checkTrace(model, x)

    def test_trace_dict_output(self):
        class TraceDictStrTensor(torch.nn.Module):
            def forward(self, a, b):
                return {"a": a, "b": b}

        class TraceDictTensorTensor(torch.nn.Module):
            def forward(self, a, b):
                return {a: b, b: a}

        x = (torch.rand(3), torch.rand(3))
        with self.assertRaisesRegex(RuntimeError, r"Encountering a dict at the output"):
            torch.jit.trace(TraceDictStrTensor(), x)

        traced_dict_str_mod = torch.jit.trace(TraceDictStrTensor(), x, strict=False)
        self.assertEqual(traced_dict_str_mod(*x), {"a": x[0], "b": x[1]})

        traced_dict_tensor_mod = torch.jit.trace(
            TraceDictTensorTensor(), x, strict=False
        )
        self.assertEqual(traced_dict_tensor_mod(*x), {x[0]: x[1], x[1]: x[0]})

    def test_trace_with_tensor_list_output(self):
        def f():
            return [torch.zeros(1), torch.zeros(5)]

        with self.assertWarnsRegex(
            torch.jit.TracerWarning, "cause the trace to be incorrect"
        ):
            torch.jit.trace(f, [])
        traced_non_strict_f = torch.jit.trace(f, [], strict=False)
        self.assertEqual(traced_non_strict_f(), f())

    def test_trace_with_number_list_output(self):
        def f():
            return [1, 5]

        with self.assertRaisesRegex(
            RuntimeError, r"Only tensors.+can be output from traced functions"
        ):
            traced_f = torch.jit.trace(f, [])

    def test_trace_with_nested_tensor_list_output(self):
        def f():
            return [[torch.zeros(1)], [torch.zeros(5)]]

        with self.assertRaisesRegex(
            RuntimeError, r"Only tensors.+can be output from traced functions"
        ):
            traced_f = torch.jit.trace(f, [])

    def test_trace_with_nested_strided_tensor_output(self):
        @torch.jit.script
        def nt_construct(values, kv_lengths):
            kv_lengths_list: List[int] = kv_lengths.tolist()
            return torch._nested_tensor_from_tensor_list(
                list(values.split(kv_lengths_list, dim=0)), None, None, None, None
            )

        def f(x, offsets):
            kv_lengths = offsets[1:] - offsets[:-1]
            return nt_construct(x, kv_lengths).cos()

        x = torch.rand(5, 4)
        offsets = torch.tensor([0, 2, 5])
        ref = f(x, offsets)
        f_t = torch.jit.trace(f, (x, offsets))
        res = f_t(x, offsets)
        self.assertEqual(ref, res)
        x2 = torch.rand((8, 4))
        offsets2 = torch.tensor([0, 2, 4, 8])
        self.assertEqual(f(x2, offsets2), f_t(x2, offsets2))

    def test_trace_variable_instantiation(self):
        def random_foo(x):
            return Variable(Variable(x) + 1.0)

        random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))

        x = torch.rand(5, 6)
        self.assertEqual(random_foo(x), random_foo_traced(x))

    def test_trace_slice_expr_complete_type(self):
        def random_foo(x):
            return x + 1.0

        random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))

        @torch.jit.script
        def random_bar(x):
            return random_foo_traced(x)[0:1]

        x = torch.rand(3, 4)
        self.assertEqual(random_bar(x), (x + 1)[0:1])

    def test_trace_inline_shape(self):
        # testing peephole optimization of size is turned into a constant
        # in script fn

        @torch.jit.script
        def tensor_size(x: torch.Tensor) -> torch.Tensor:
            return torch.tensor([x.size()[0]])

        self.assertEqual(
            tensor_size(
                torch.rand(
                    15,
                )
            ),
            torch.tensor([15]),
        )

        traced_tensor_size = torch.jit.trace(
            tensor_size,
            torch.rand(
                7,
            ),
        )

        self.assertEqual(
            traced_tensor_size(
                torch.rand(
                    15,
                )
            ),
            torch.tensor([15]),
        )

        @torch.jit.script
        def use_device(x):
            return torch.zeros_like(x, device=x.device)

        def foo(x):
            return use_device(x)

        traced_tensor_size = torch.jit.trace(
            foo,
            torch.rand(
                7,
            ),
        )
        self.run_pass("inline", traced_tensor_size.graph)
        FileCheck().check("prim::device").run(traced_tensor_size.graph)

    def test_trace_save(self):
        def fn(x):
            return x + 2

        def check(func):
            with TemporaryFileName() as fname:
                func.save(fname)
                loaded = torch.jit.load(fname)
                input = torch.randn(2, 2)
                self.assertEqual(func(input), loaded(input))

        out = torch.jit.trace(fn, (torch.ones(2, 2),))
        check(out)

    def test_trace_optioanl_dtype(self):
        class Test(torch.nn.Module):
            def forward(self):
                return torch.arange(5)

        traced = torch.jit.trace(Test(), ())
        torch.allclose(traced(), Test()())

    def test_trace_save_load_copy(self):
        class Test(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 3)

            def forward(self, x):
                return self.conv(x)

        traced = torch.jit.trace(Test(), torch.rand(1, 3, 224, 224))
        buffer = io.BytesIO()
        torch.jit.save(traced, buffer)
        buffer.seek(0)
        loaded = torch.jit.load(buffer)
        # should work
        copy.copy(loaded)
        copy.deepcopy(loaded)

    def test_trace_export_fns(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3

            @torch.jit.export
            def __getstate__(self):
                return (3, self.training)

            @torch.jit.export
            def __setstate__(self, state):
                self.a = state[0]
                self.training = state[1]

            def forward(self, x):
                return x + self.a

        f = Foo()

        traced = torch.jit.trace(f, (torch.rand(3, 4),))
        expected_names = ["__getstate__", "__setstate__"]

        def check(mod):
            self.assertTrue(
                all(name in mod._c._method_names() for name in expected_names)
            )

        check(traced)

        imported = self.getExportImportCopy(traced)
        check(imported)

    def test_trace_export_fns_recursive(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3

            @torch.jit.export
            def __getstate__(self):
                return (3, self.training)

            @torch.jit.export
            def __setstate__(self, state):
                self.a = state[0]
                self.training = state[1]

            def forward(self, x):
                return x + self.a

        class Wrapper(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Foo()

            def forward(self, x):
                return self.foo(x)

        f = Wrapper()

        traced = torch.jit.trace(f, (torch.rand(3, 4),))
        expected_names = ["__getstate__", "__setstate__"]

        def check(mod):
            self.assertTrue(
                all(name in mod._c._method_names() for name in expected_names)
            )

        check(traced.foo)

        imported = self.getExportImportCopy(traced)
        check(imported.foo)

        # Note that Bar's forward can only be traced, but not scripted
        class Bar(nn.Module):
            @torch.jit.export
            def addTwo(self, x):
                return x + 2

            def forward(self, input):
                return (lambda a: a + 1)(input)  # noqa: PLC3002

        # When tracing Bar as a submodule, we only want to script the
        # exported methods, and we want to keep the forwards still
        # being traced.
        class WrapperExports(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.bar = Bar()

            @torch.jit.export
            def addOne(self, x):
                return x + 1

            def forward(self, x):
                return self.bar(x)

        f = WrapperExports()

        traced = torch.jit.trace(f, (torch.rand(3, 4),))
        expected_names = ["addOne"]
        check(traced)

    def test_trace_autograd_function(self):
        class TestFunc(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                return torch.neg(input)

            @staticmethod
            def backward(ctx, grad_output):
                return torch.neg(grad_output)

        class TracedModule(torch.nn.Module):
            def forward(self, x):
                return torch.relu(TestFunc.apply(x))

        class Wrapper(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tm = TracedModule()

            def forward(self, x):
                return self.tm(x)

        traced = torch.jit.trace(Wrapper(), (torch.rand(3, 4),))

    def test_trace_multi_output_function(self):
        # An autograd.Function with two outputs.
        # It swaps inputs so we can check if shape
        # handling is correct in TorchScript.
        class Foo(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x, y):
                return y, x

            @staticmethod
            def backward(ctx, du, dv):
                return dv, du

        class Bar(torch.nn.Module):
            def forward(self, x, y):
                x = x.relu()
                y = y.relu()
                z = Foo.apply(x, y)
                return z

        x = torch.rand(3, 2, dtype=torch.double)
        y = torch.rand(1, 2, dtype=torch.double)

        # Generate JIT IR.
        traced = torch.jit.trace(Bar(), (x, y))
        print(traced.graph)

        # Expected output schema of the custom autograd.Function.
        schema = (
            "(Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu), "
            "Double(3, 2, strides=[2, 1], requires_grad=0, device=cpu)) "
            "= ^Foo"
        )

        # See if expected schema exists.
        FileCheck().check(schema).run(traced.graph)

        # Also examine if the graph is runnable and produces
        # the right result.
        u, v = traced(x, y)
        self.assertEqual(u, y)
        self.assertEqual(v, x)

    def test_interpolate_trace(self):
        class test(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)

            def forward(self, x):
                y = self.conv(x)
                w = nn.functional.interpolate(
                    y, mode="bilinear", align_corners=False, scale_factor=3
                )
                return w

        f = test()
        # no failure
        g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),))
        x = torch.zeros(1, 1, 14, 14)
        # constants not baked in
        self.assertEqual(g(x), f(x))

    @_tmp_donotuse_dont_inline_everything
    def test_trace_optional(self):
        @torch.jit.script
        def test(x: Optional[Tensor]):
            if x is None:
                return torch.zeros(1)
            else:
                return x

        def test_none():
            return test(None)

        def test_tensor():
            return test(torch.zeros(2))

        f_none = torch.jit.trace(test_none, ())
        self.assertEqual(f_none(), torch.zeros(1))

        f_tensor = torch.jit.trace(test_tensor, ())
        self.assertEqual(f_tensor(), torch.zeros(2))

        graph = f_tensor.graph
        FileCheck().check('name="test"').check_next("prim::CallFunction").run(graph)

    def test_trace_nested_datatypes(self):
        @torch.jit.script
        def foo(x):
            return [[x + 1, x - 1], [x + 2, x - 2]]

        def bar(x):
            list_stuff = foo(x)
            return list_stuff[0][0], list_stuff[1][1]

        traced = torch.jit.trace(bar, torch.rand(3, 4))
        x = torch.rand(5, 6)
        self.assertEqual(bar(x), traced(x))

    @_tmp_donotuse_dont_inline_everything
    def test_call_traced_fn_from_traced_module(self):
        @_trace(torch.rand(3, 4))
        def traced_fn(x):
            return torch.neg(x)

        class TracedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 5))

            def forward(self, x):
                return traced_fn(torch.mm(x, self.param))

        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))

        # Note: neg op from the traced function should be properly inlined
        FileCheck().check("aten::mm").check('name="traced_fn"').check_next(
            "prim::CallFunction"
        ).run(str(tm.graph))

    @_tmp_donotuse_dont_inline_everything
    def test_call_traced_module_from_traced_module(self):
        class TracedModule1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(5, 7))

            def forward(self, x):
                return torch.mm(x, self.param)

        class TracedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 5))
                self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5))

            def forward(self, x):
                return self.mod(torch.mm(x, self.param)) + 1.0

        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))

        FileCheck().check("aten::mm").check("prim::CallMethod").check_same(
            "forward"
        ).check("aten::add").run(str(tm.graph))

    def test_index_put_trace_with_view(self):
        @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
        def test_index_put(target, indices, rhs):
            target[indices] = rhs
            return target

        FileCheck().check("aten::view").check("index_put_").run(
            str(test_index_put.graph)
        )

    def test_index_put_trace_without_view(self):
        @_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
        def test_index_put(target, indices, rhs):
            target[indices] = rhs
            return target

        FileCheck().check_not("aten::view").check("index_put_").run(
            str(test_index_put.graph)
        )

    @suppress_warnings
    def test_trace_checker_dot_data(self):
        with self.assertRaisesRegex(
            torch.jit.TracingCheckError,
            r"Tensor-valued Constant nodes differed in value " r"across invocations",
        ):

            @_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
            def foo(x):
                y = x.data
                return x + y

    @suppress_warnings
    def test_trace_checker_control_flow(self):
        def foo(x):
            for _ in range(x.size(0)):
                x = torch.neg(x)
            return x

        with self.assertRaisesRegex(
            torch.jit.TracingCheckError, r"Graphs differed across invocations!"
        ):
            torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])

    @suppress_warnings
    def test_trace_checker_memoization(self):
        with self.assertRaisesRegex(
            torch.jit.TracingCheckError, r"Graphs differed across invocations!"
        ):

            def foo(x):
                if not hasattr(foo, "cache"):
                    foo.cache = torch.neg(x)
                return x + foo.cache

            traced = torch.jit.trace(
                foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)]
            )

    def test_trace_checker_slice_lhs(self):
        def foo(x):
            for i in range(3):
                x[i, :] = torch.zeros(4)
            return x

        self.checkTrace(foo, (torch.rand(3, 4),), inputs_require_grads=False)

    def test_trace_checker_inplace_on_view(self):
        def foo(x):
            x.view(-1).add_(-x.view(-1))
            return x

        with self.assertWarnsRegex(
            torch.jit.TracerWarning,
            "Output nr 1. of the traced function does not match the "
            "corresponding output of the Python function",
        ):
            torch.jit.trace(
                foo,
                torch.rand(3, 4),
                check_inputs=[torch.rand(5, 6)],
                _force_outplace=True,
            )

    def test_lhs_index_fails(self):
        def foo(x):
            x[0, 1] = 4
            return x

        with self.assertWarnsRegex(
            torch.jit.TracerWarning, "cause the trace to be incorrect"
        ):
            torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)

    def test_lhs_index_trivial(self):
        def foo(y, x):
            y[...] = x
            return y

        self.checkTrace(
            foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False
        )

    def test_inplace_warn(self):
        def foo(x):
            x.view(-1).add_(-x.view(-1))
            return x

        with self.assertWarnsRegex(
            torch.jit.TracerWarning, "cause the trace to be incorrect"
        ):
            torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)

    @suppress_warnings
    def test_trace_checker_dropout_train(self):
        def foo(x):
            return torch.dropout(x, p=0.5, train=True)

        with self.assertWarnsRegex(
            torch.jit.TracerWarning,
            "Output nr 1. of the traced function does not match the "
            "corresponding output of the Python function",
        ):
            torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])

        with self.assertWarnsRegex(
            torch.jit.TracerWarning, "Trace had nondeterministic nodes"
        ):
            torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])

    def test_trace_checker_dropout_notrain(self):
        input = torch.rand(3, 4)

        @_trace(input)
        def foo(x):
            return torch.dropout(x, p=0.5, train=False)

        self.assertEqual(foo(input), input)

    def test_trace_contiguous(self):
        def foo(x):
            return x[:, :, ::2].contiguous().view(12)

        x = torch.rand(2, 3, 4)
        traced = torch.jit.trace(foo, (x,))
        y = traced(x)
        self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())

    # This tests the logic in THPVariable_contiguous. There is short-circuiting
    # code that prevents us from even getting to VariableType::contiguous, since
    # it is an optimization that prevents us from acquiring the GIL for touching
    # the device. We needed to add the tracing logic directly into the
    # THPVariable_contiguous function only for the path where we are skipping
    # dispatch into contiguous. We should see an aten::contiguous in this trace!
    def test_trace_contiguous_short_circuit(self):
        def foo(x):
            return x.contiguous()

        x = torch.rand(2, 3, 4)
        traced = torch.jit.trace(foo, (x,))
        FileCheck().check("aten::contiguous").run(str(traced.graph))

    def test_trace_inverse(self):
        def foo(x):
            return ~x

        foo_traced = torch.jit.trace(foo, torch.zeros(3, 4, dtype=torch.uint8))
        eg = torch.zeros(3, dtype=torch.uint8)
        self.assertEqual(foo_traced(eg), foo(eg))

    def test_trace_modulelist(self):
        class MySubmod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x)

        class MyMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ml = torch.nn.ModuleList([MySubmod(), MySubmod()])

            def forward(self, x):
                for mod in self.ml:
                    x = mod(x)
                return x

        traced = torch.jit.trace(MyMod(), (torch.rand(3, 4),))

    def test_trace_fork_join_and_module(self):
        class MySubmod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x), torch.neg(x)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ml = torch.nn.ModuleList([MySubmod() for i in range(2)])

            def forward(self, x):
                futs = []
                for i in range(2):
                    futs.append(torch.jit._fork(self.ml[i], x))

                results = []
                for i in range(2):
                    results.append(torch.jit._wait(futs[i])[0])

                return torch.stack(results)

        m = Mod()
        traced = torch.jit.trace(m, torch.rand(3, 4))

    def test_trace_invert_module_hierarchy(self):
        class MySubmod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                return self.relu(x), torch.neg(x)

        class MyFunctionalMod(torch.nn.Module):
            def forward(self, x, submod):
                return submod(x)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.sm = MySubmod()
                self.fm = MyFunctionalMod()

            def forward(self, x):
                return self.fm(x, self.sm)

        torch.jit.trace(Mod(), (torch.rand(3, 4),))

    @skipIfCrossRef
    def test_trace_records_names(self):
        def foo(bar, baz):
            baz = bar + 3
            quick_brown_fox = torch.neg(baz)
            for _ in range(20):
                yeet = quick_brown_fox - 3.14
            return yeet

        traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
        graph_str = str(traced.graph)
        assert "bar" in graph_str
        assert "baz" in graph_str
        assert "quick_brown_fox" in graph_str

    @skipIfTorchDynamo("Not a suitable test for TorchDynamo")
    def test_tracing_hooks(self):
        class Net(nn.Module):
            def forward(self, x):
                return x + x

        def test_hook(is_post_hook, hook, fc):
            n = Net()
            if is_post_hook:
                n.register_forward_hook(hook)
            else:
                n.register_forward_pre_hook(hook)

            module = torch.jit.trace(n, (torch.tensor(1.0),))

            eager_input = torch.tensor(1.0)
            eager_out = n(eager_input)

            fc.run(module.forward.graph)
            input = torch.tensor(1.0)
            output = module(input)

            self.assertEqual(input, eager_input)
            self.assertEqual(output, eager_out)

        def hook_no_return(mod, input, output):
            input[0].add_(1)
            output.sub_(1)

        fc = FileCheck().check("add(").check("add_(").check("sub_(")
        test_hook(True, hook_no_return, fc)

        def hook_return(mod, input, output):
            input[0].add_(1)
            return output - 3

        fc = FileCheck().check("add(").check("add_(").check("sub(")
        test_hook(True, hook_return, fc)

        b = torch.tensor(3.0)

        def captured_hook(mod, input, output):
            return output - b

        fc = FileCheck().check("add(").check("sub(")
        test_hook(True, captured_hook, fc)

        def pre_hook_no_ret(mod, input):
            input[0].add_(3)

        fc = FileCheck().check("add_(").check("add(")
        test_hook(False, pre_hook_no_ret, fc)

        def pre_hook_ret(mod, input):
            return input[0] - 4

        fc = FileCheck().check("sub(").check("add(")
        test_hook(False, pre_hook_ret, fc)

    def test_tracing_backward_hook_error(self):
        class Net(nn.Module):
            def forward(self, x):
                return x + x

        n = Net()

        def backward_hook(module, grad_input, grad_output):
            pass

        n.register_backward_hook(backward_hook)
        with self.assertRaisesRegex(Exception, "backward hooks assigned"):
            torch.jit.trace(n, (torch.tensor(1.0),))

    def test_tracing_multiple_methods(self):
        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 3)

            def forward(self, x):
                return self.conv(x)

            def weighted_kernel_sum(self, weight):
                return weight * self.conv.weight

        example_weight = torch.rand(1, 1, 3, 3)
        example_forward_input = torch.rand(1, 1, 3, 3)
        inputs = {
            "forward": example_forward_input,
            "weighted_kernel_sum": example_weight,
        }
        n = Net()
        module = torch.jit.trace_module(n, inputs)

        check_inputs = []
        for i in range(2):
            check_weight = torch.rand(1, 1, 3, 3)
            check_forward_input = torch.rand(1, 1, 3, 3)
            check_inputs.append(
                {"forward": check_forward_input, "weighted_kernel_sum": check_weight}
            )
        module = torch.jit.trace_module(
            n, inputs, check_trace=True, check_inputs=check_inputs
        )
        self.assertTrue(module._c._has_method("forward"))
        self.assertTrue(module._c._has_method("weighted_kernel_sum"))

        module = torch.jit.trace(n.forward, example_forward_input)
        module = torch.jit.trace(
            n.forward,
            example_forward_input,
            check_trace=True,
            check_inputs=[example_forward_input],
        )
        with self.assertRaisesRegex(
            AttributeError,
            "trace doesn't support compiling individual module's functions",
        ):
            module = torch.jit.trace(n.weighted_kernel_sum, inputs)

    def test_tensor_with_grad_as_constant(self):
        param = torch.randn(3).requires_grad_()
        x = torch.randn(3)

        def f(x):
            return x + param

        with self.assertRaisesRegex(
            RuntimeError, "Cannot insert a Tensor that requires grad as a constant"
        ):
            torch.jit.trace(f, x)

    def test_non_tensor_tracing(self):
        def f(x):
            return x + param  # noqa: F821

        with self.assertRaisesRegex(
            RuntimeError, r"Type 'Tuple\[int\]' cannot be traced"
        ):
            torch.jit.trace(f, (1,))

    def test_trace_skip_none_submodule(self):
        class TestModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.submod = torch.nn.Linear(3, 4)
                self.submod = None

            def forward(self, inputs):
                return inputs

        m = TestModule()
        tm = torch.jit.trace(m, torch.tensor(1.0))
        self.assertFalse(hasattr(tm, "submod"))

    def test_trace_with_conditional_property(self):
        class Net(nn.Module):
            def __init__(self, attr=None):
                super().__init__()
                if attr is not None:
                    self._attr = attr
                self.attr_name = "_attr"

            @property
            def attr(self):
                return getattr(self, self.attr_name)

            def forward(self, x):
                return x

        x = torch.ones(1)
        torch.jit.trace(Net(), x)

    def test_trace_func_argument_names_captured(self):
        def fn(first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor:
            return first_arg + second_arg

        traced_fn = torch.jit.trace(fn, (torch.ones(1), torch.ones(1)))
        FileCheck().check("first_arg").check_next("second_arg").run(
            str(traced_fn.graph)
        )

    def test_trace_partial_func_argument_names_captured(self):
        def fn(first_arg: torch.Tensor, second_arg=1) -> torch.Tensor:
            return first_arg + second_arg

        traced_fn = torch.jit.trace(fn, (torch.ones(1),))
        FileCheck().check("first_arg").check_not("second_arg").run(str(traced_fn.graph))

    def test_trace_module_argument_names_captured(self):
        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 3)

            def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor):
                return self.conv(first_arg) + second_arg

        m = TestModule()
        example_input = (torch.ones(1, 1, 3, 3), torch.ones(1, 1, 3, 3))

        # Explicitly tracing module's forward method
        traced_module_forward = torch.jit.trace(m.forward, example_input)
        FileCheck().check("first_arg").check_next("second_arg").run(
            str(traced_module_forward.graph)
        )

        # Tracing module's directly
        traced_module = torch.jit.trace(m, example_input)
        FileCheck().check("first_arg").check_next("second_arg").run(
            str(traced_module.graph)
        )

    def test_trace_checking_with_deprecated_name(self):
        class MyClass(torch.nn.Module):
            def __init__(self) -> None:
                super(MyClass, self).__init__()

            def forward(self, x, y, **deprecated_arguments):
                if len(deprecated_arguments) > 0:
                    raise RuntimeError(
                        f"Got unexpected arguments: {deprecated_arguments}"
                    )
                return x + y

        model = MyClass()
        m2 = torch.jit.trace(model, (torch.ones(1), torch.ones(1)))
        m3 = torch.jit.trace(
            model,
            example_kwarg_inputs={"x": torch.ones(1), "y": torch.ones(1)},
            strict=False,
        )

    def test_trace_with_tuple_tensor(self):
        class MyClass(torch.nn.Module):
            def __init__(self) -> None:
                super(MyClass, self).__init__()

            def forward(self, x, y):
                return x + y[0] + y[1]

        model = MyClass()
        traced_model = torch.jit.trace(
            model, (torch.ones(1), (torch.ones(1), torch.ones(1)))
        )
        input_dict = {
            "x": torch.tensor([2, 3]),
            "y": (torch.tensor([5, 6]), torch.tensor([7, 8])),
        }
        self.assertEqual(model(**input_dict), traced_model(**input_dict))
        traced_model = torch.jit.trace(
            model,
            example_kwarg_inputs={
                "x": torch.ones(1),
                "y": (torch.ones(1), torch.ones(1)),
            },
        )
        self.assertEqual(model(**input_dict), traced_model(**input_dict))

    def test_trace_no_duplicated_lifted_input_output(self):
        class Normalize(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.norm = nn.GroupNorm(num_groups=32, num_channels=32)

            def forward(self, x, y):
                if y is None:
                    y = x
                else:
                    y = self.norm(y)
                y = y * 2
                return y

        class G(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.norm = Normalize()

            def forward(self, x):
                A = self.norm(x, None)
                B = F.relu(A)
                return A, B

        class Net(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.g = G()
                self.norm_1 = Normalize()

            def forward(self, x):
                hs = self.g(x)
                A, B = hs
                h = self.norm_1(B, A)
                return h

        net = Net()
        net = net.eval()
        x = torch.randn(1, 32, 16, 16)
        traced = torch.jit.trace(net, x)
        FileCheck().check_not("prim::TupleUnpack").run(str(traced.graph))


@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestMixTracingScripting(JitTestCase):
    def test_trace_script(self):
        @torch.jit.script
        def func1(x: Tuple[Tensor, Tensor]) -> Tensor:
            return x[0] + x[1]

        @torch.jit.script
        def func2(x: List[Tensor]) -> Tensor:
            return x[0] + x[1]

        a = torch.randn(5)
        b = torch.randn(5)

        self.checkTrace(func1, ((a, b),))
        self.checkTrace(func2, ((a, b),))

        @torch.jit.script
        def func3(
            x: Tensor, method: str = "bilinear", align_corners: bool = True
        ) -> Tensor:
            hw = x.shape[2:4]
            return F.interpolate(x, hw, mode=method, align_corners=align_corners)

        inp = torch.rand(1, 3, 6, 6)
        self.checkTrace(func3, (inp,))

        @torch.jit.script
        def func4(x: Tensor, a: List[Optional[str]]) -> Tensor:
            if len(a) == 2:
                return x + 2
            else:
                return x

    def test_trace_mixed_by_script_with_dict_output(self):
        @torch.jit.script
        def return_dict(input: torch.Tensor) -> Dict[str, torch.Tensor]:
            return {"foo": input + 1}

        class TraceModule(torch.nn.Module):
            def forward(self, input):
                dict = return_dict(input)
                return dict["foo"] + dict["foo"]

        x = torch.ones(1)
        tm = torch.jit.trace(TraceModule(), x)
        self.assertEqual(tm(x), x + 1 + x + 1)

    def test_trace_of_script(self):
        @torch.jit.script
        def foo(a, c):
            b = 0.0
            if bool(a == 0.0):
                b = 1.0
            return b + c

        a = torch.ones(1, dtype=torch.float)

        @_trace(torch.zeros(1, dtype=torch.float))
        def use(b):
            return foo(b - 1.0, a) + 1.0

        # test we propagated shapes through the function
        self.assertTrue("Dynamic" not in str(use.graph))

        self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
        self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))

    def test_trace_with_size(self):
        @_trace(torch.zeros(1, 1))
        def foo(x):
            return x + 1

        @torch.jit.script
        def bar(x):
            y = int(foo(x))
            if 1 == 1:
                y = 7
            return y + 1

        self.assertEqual(8, bar(torch.ones(1, 1)))

    def test_tracing_slicing(self):
        @_trace(torch.zeros(10))
        def foo_trace(x):
            return x[-5:-3]

        @torch.jit.script
        def foo_script(x):
            return x[-5:-3]

        def foo(x):
            return x[-5:-3]

        a = torch.arange(0, 8)
        b = torch.arange(0, 20)
        self.assertEqual(foo_trace(a), foo_script(a))
        self.assertEqual(foo_trace(a), foo(a))
        self.assertNotEqual(foo_trace(a), foo_trace(b))

    def test_tracing_indexing(self):
        @_trace(torch.zeros(10))
        def foo_trace(x):
            return x[-2]

        @torch.jit.script
        def foo_script(x):
            return x[-2]

        def foo(x):
            return x[-2]

        a = torch.arange(0, 8)
        b = torch.arange(0, 20)
        self.assertEqual(foo_script(a), foo_trace(a))
        self.assertEqual(foo_trace(a), foo(a))
        self.assertNotEqual(foo_trace(a), foo_trace(b))

    def test_trace_hierarchy(self):
        # Test that we preserve the module hierarchy for a ScriptModule
        # submodule during tracing

        class AnotherScriptMod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(1, 2, 3))

            @torch.jit.script_method
            def bar(self):
                return torch.zeros(4, 5)

        class SomeScriptMod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.asm = AnotherScriptMod()

            @torch.jit.script_method
            def foo(self):
                return torch.zeros(3, 4)

            @torch.jit.script_method
            def bar(self):
                return torch.zeros(4, 3)

        class TraceMe(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.ssm = SomeScriptMod()

            def forward(self, x):
                return self.ssm.bar() + x

        orig = TraceMe()
        traced = torch.jit.trace(orig, (torch.rand(4, 3),))
        # for each of these checks, check that *BOTH* the underlying
        # _C.ScriptModule object has the expected method/param, as well as the
        # Python object that wraps it.
        self.assertTrue(traced.ssm._c._has_method("foo"))
        self.assertTrue(hasattr(traced.ssm, "foo"))

        imported = self.getExportImportCopy(traced)

        self.assertTrue(imported.ssm._c._has_method("foo"))
        self.assertTrue(hasattr(imported.ssm, "foo"))

        self.assertTrue(imported.ssm.asm._c._has_method("bar"))
        self.assertTrue(hasattr(imported.ssm.asm, "bar"))

        self.assertTrue(hasattr(imported.ssm.asm, "param"))

    def test_trace_parameter(self):
        class Param(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_parameter("bias", nn.Parameter(torch.empty(4, 4)))

            def forward(self, x):
                return x

        class M3(torch.jit.ScriptModule):
            def __init__(self, model):
                super().__init__()
                self.traced = torch.jit.trace(model, (torch.rand(3, 3)))

            @torch.jit.script_method
            def forward(self, x):
                return self.traced(x)

        class M2(nn.Module):
            def __init__(self, model):
                super().__init__()
                self.module = M3(model)

            def forward(self, x):
                return self.module(x)

        class M1(torch.jit.ScriptModule):
            def __init__(self, model):
                super().__init__()
                self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3)))

            @torch.jit.script_method
            def forward(self, x):
                return self.traced(x)

        with torch.jit.optimized_execution(False):
            module = M1(Param())
            f = io.BytesIO()
            torch.jit.save(module, f)

    @_tmp_donotuse_dont_inline_everything
    def test_call_script_fn_from_traced_module(self):
        @torch.jit.script
        def scripted_fn(x):
            return torch.neg(x)

        class TracedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 5))

            def forward(self, x):
                return scripted_fn(torch.mm(x, self.param))

        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
        FileCheck().check("aten::mm").check('name="scripted_fn"').check(
            "prim::CallFunction"
        ).run(str(tm.graph))

    @_tmp_donotuse_dont_inline_everything
    def test_call_script_module_from_traced_module(self):
        class ScriptMod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.param_foo = torch.nn.Parameter(torch.rand(5, 7))

            @torch.jit.script_method
            def forward(self, x):
                return torch.mm(x, self.param_foo)

        class TracedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 5))
                self.mod = ScriptMod()

            def forward(self, x):
                return self.mod(torch.mm(x, self.param)) + 1.0

        tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))

        FileCheck().check("aten::mm").check("prim::CallMethod").check_same(
            "forward"
        ).check("aten::add").run(str(tm.graph))

    @_tmp_donotuse_dont_inline_everything
    def test_call_traced_fn_from_script_fn(self):
        @_trace(torch.rand(3, 4))
        def traced_fn(x):
            return torch.neg(x)

        @torch.jit.script
        def script_fn(x):
            return traced_fn(x) + 1

        FileCheck().check("prim::CallFunction").check("aten::add").run(
            str(script_fn.graph)
        )

    def test_call_traced_mod_from_script_fn(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "Cannot call a ScriptModule that is not a submodule of the caller",
        ):

            class TracedModule(torch.nn.Module):
                def forward(self, x):
                    return torch.mm(x, torch.zeros(4, 3))

            tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))

            @torch.jit.script
            def script_fn(x):
                return tm(x) + 1

    @_tmp_donotuse_dont_inline_everything
    def test_call_tracing_fn_from_script_module(self):
        @_trace(torch.rand(3, 3))
        def traced_fn(x):
            return torch.neg(x)

        class ScriptMod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 3))

            @torch.jit.script_method
            def forward(self, x):
                return traced_fn(torch.mm(x, self.param))

        sm = ScriptMod()
        FileCheck().check("aten::mm").check("prim::CallFunction").run(
            str(sm.forward.graph)
        )

    @_tmp_donotuse_dont_inline_everything
    def test_call_tracing_mod_from_script_module(self):
        class TracedMod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(3, 5))

            def forward(self, x):
                return torch.mm(x, self.param)

        class ScriptMod(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.rand(4, 3))
                self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3))

            @torch.jit.script_method
            def forward(self, x):
                return self.tm(torch.mm(x, self.param))

        sm = ScriptMod()
        FileCheck().check("aten::mm").check("prim::CallMethod").run(str(sm.graph))

    def test_script_inline_trace_multiple_args(self):
        class M(torch.nn.Module):
            def forward(self, input, input2):
                return input + input2

        class M2(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.m = torch.jit.trace(M(), (torch.zeros(4, 3), torch.zeros(4, 3)))

            @torch.jit.script_method
            def forward(self, inp):
                return self.m(inp, inp)

        with torch.jit.optimized_execution(False):
            m2 = M2()
            m2(torch.zeros(4, 3))

    def test_trace_dict_mix_script(self):
        class testB(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, feature_map: Dict[str, List[Tensor]]) -> Tensor:
                output = []
                for j in feature_map.values():
                    output.append(self.linear(j[0]))

                return torch.stack(output)

        class testA(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.b = torch.jit.script(testB())

            def forward(self, input_map: Dict[str, List[Tensor]]) -> Tensor:
                feature_map = {}
                for i, j in input_map.items():
                    feature_map[i] = [j[0]]

                return self.b(feature_map)

        input_map = {
            "1": [torch.rand(2, 2), torch.rand(2, 2)],
            "3": [torch.rand(2, 2), torch.rand(2, 2)],
        }
        model = testA()
        traced_model = torch.jit.trace(model, input_map)
        new_input_map = {
            "1": [torch.rand(2, 2), torch.randn(2, 2)],
            "3": [torch.rand(2, 2), torch.rand(2, 2)],
        }
        self.assertEqual(model(new_input_map), traced_model(new_input_map))

    def test_trace_script_returning_complex_dict(self):
        """Tracing over a script function returning a dictionary should work.
        The dictionary can should be able to contain other containers (like a tuple) recursively.
        """

        class ReturnsDict(torch.nn.Module):
            def forward(
                self,
                id_score_list: Dict[
                    str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
                ],
            ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
                # do some random operations and then return a dict of the same structure
                v = id_score_list["1000"]
                idx_keys = v[1] - 1500000
                weights = v[2]
                result = {"1000": (v[0], idx_keys, weights)}
                return result

        class ChecksDict(torch.nn.Module):
            def forward(
                self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
            ):
                v = input["1000"]
                return v[1] + 1

        class TestModule(torch.nn.Module):
            def __init__(self, checks_dict, returns_dict):
                super().__init__()
                self.checks_dict = checks_dict
                self.returns_dict = returns_dict

            def forward(
                self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
            ):
                foo = self.returns_dict(input)
                return self.checks_dict(foo)

        input1 = {
            "1000": (
                torch.tensor([0]),
                torch.tensor([], dtype=torch.int64),
                torch.tensor([]),
            )
        }

        input2 = {
            "1000": (
                torch.tensor([0]),
                torch.tensor([1500000, 1500004], dtype=torch.int64),
                torch.tensor([2.0, 3.0]),
            )
        }

        checks_dict = torch.jit.script(ChecksDict())
        returns_dict = torch.jit.script(ReturnsDict())
        eager_module = TestModule(checks_dict, returns_dict)
        traced_module = torch.jit.trace(eager_module, input1)
        self.assertEqual(traced_module(input1), eager_module(input1))
        self.assertEqual(traced_module(input2), eager_module(input2))

    def test_trace_returning_dict_with_tensor_tuples(self):
        """Tracing over a module returning a dictionary whose values are tuples of tensors
        should work.
        """

        class ReturnsDict(torch.nn.Module):
            def forward(
                self, k: torch.Tensor, v: torch.Tensor
            ) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
                x = 2 * k
                y = 3 * v
                result = {"imakey": (x, y)}
                return result

        class ReturnsBadDict(torch.nn.Module):
            def forward(
                self, k: torch.Tensor, v: torch.Tensor
            ) -> Dict[str, Tuple[torch.Tensor, float]]:
                x = 2 * k
                result = {"imakey": (x, 1)}
                return result

        mod = ReturnsDict()
        traced_module = torch.jit.trace(
            mod, [torch.ones(1), torch.ones(1)], strict=False
        )
        out = traced_module(torch.ones(1), torch.ones(1))
        expected = {"imakey": (torch.tensor([2.0]), torch.tensor([3.0]))}
        self.assertEqual(out, expected)

        with self.assertRaisesRegex(
            RuntimeError, "cannot be understood by the tracer, only outputs matching"
        ):
            mod = ReturnsBadDict()
            traced_module = torch.jit.trace(
                mod, [torch.ones(1), torch.ones(1)], strict=False
            )

    def test_trace_linear(self):
        m = torch.nn.Linear(20, 20)
        inp = torch.rand([20, 20])
        self.checkTrace(m, (inp,))
        g = torch.jit.trace(m, (inp,)).graph
        FileCheck().check("aten::linear").run(g)

    def test_traced_module_implements_interface(self):
        @torch.jit.interface
        class TestModuleInterface(nn.Module):
            def forward(
                self, first_arg: torch.Tensor, second_arg: torch.Tensor
            ) -> torch.Tensor:
                pass

        make_global(TestModuleInterface)

        class TestModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = nn.Conv2d(1, 1, 3)

            def forward(
                self, first_arg: torch.Tensor, second_arg: torch.Tensor
            ) -> torch.Tensor:
                return self.conv(first_arg) + second_arg

        def fn_takes_interface(x: TestModuleInterface):
            ones = torch.ones(1, 1, 3, 3)
            return x.forward(ones, ones)

        scripted_test_module = torch.jit.script(TestModule())
        self.checkScript(fn_takes_interface, (scripted_test_module,))

    def test_traced_module_contains_scripted_interface_types(self):
        class LeafModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.nn.Parameter(torch.rand(19))

            def forward(self, input: torch.Tensor):
                return input + self.weight

        class LowerModuleImpl(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = LeafModule()

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                return self.leaf(input)

        @torch.jit.interface
        class LowerModuleInterface(torch.nn.Module):
            def forward(self, input: torch.Tensor) -> torch.Tensor:
                pass

        class MiddleModule(torch.nn.Module):
            lower: LowerModuleInterface

            def __init__(self, feature_processor_modules=None):
                super().__init__()
                self.lower = LowerModuleImpl()

            def forward(self, input):
                return self.lower(input)

        class WrapperModule(torch.nn.Module):
            def __init__(self, m):
                super().__init__()
                self.middle = m

            def forward(self, input):
                return self.middle(input)

        class TopModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                m = MiddleModule()
                m = torch.jit.script(m)
                self.sub1 = m
                self.sub2 = WrapperModule(m)

            def forward(self, input: torch.Tensor):
                return self.sub1(input) + self.sub2(input)

        top = TopModule()
        top_example_input = torch.ones(1)
        torch.jit.trace(top, top_example_input)

    def test_jit_trace_callfunction_return_shapes(self):
        # a torch.jit.script function gets inserted as a CallFunction node
        @torch.jit.script
        def inner_fn(x):
            return torch.cat((x, x))

        def outer_fn(x, y):
            return inner_fn(x + y).relu()

        x, y = [torch.rand((2, 2), dtype=torch.float) for _ in range(2)]
        fn_t = torch.jit.trace(outer_fn, (x, y))

        # expect that the CallFunction node return type has shape information on it.
        FileCheck().check("Float").check("4, 2").check("CallFunction").run(fn_t.graph)
        for n in fn_t.graph.nodes():
            if n.kind() == "prim::CallFunction":
                self.assertTrue(n.output().isCompleteTensor())
