# Owner(s): ["NNC"]

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import unittest
import itertools

from torch.testing._internal.common_utils import suppress_warnings, num_profiled_runs, run_tests, skipIfTorchDynamo

from torch.testing._internal.jit_utils import JitTestCase, TensorExprTestOptions

# Whether torch reports LLVM as enabled (torch._C._llvm_enabled()).
# BaseTestClass uses this to decide whether bfloat16 is added to the dtypes
# that the tests iterate over.
LLVM_ENABLED = torch._C._llvm_enabled()

class BaseTestClass(JitTestCase):
    """Shared fixture for the NNC fuser tests.

    Creates a ``TensorExprTestOptions`` in ``setUp`` and calls its
    ``restore()`` in ``tearDown``. Also exposes the device and dtype lists
    that the individual tests iterate over.
    """

    def setUp(self):
        super().setUp()
        self.tensorexpr_options = TensorExprTestOptions()
        # CUDA is only exercised when it is available on this machine.
        if torch.cuda.is_available():
            self.devices = ['cpu', 'cuda']
        else:
            self.devices = ['cpu']
        # bfloat16 coverage is only added when LLVM is enabled.
        if LLVM_ENABLED:
            self.dtypes = [torch.float32, torch.bfloat16]
        else:
            self.dtypes = [torch.float32]

    def tearDown(self):
        self.tensorexpr_options.restore()
        super().tearDown()

    def assertLastGraphAllFused(self):
        """Assert that the most recently executed optimized graph is fully fused."""
        graph = torch.jit.last_executed_optimized_graph()
        self.assertAllFused(graph)


def warmup_and_run_forward(f, *args):
    """Call ``f(*args)`` once per profiled run plus one more; return the last result.

    The extra call after the profiling runs is the one that goes through the
    optimized graph, so the value returned is that run's output.
    """
    total_runs = torch._C._jit_get_num_profiled_runs() + 1
    for _ in range(total_runs):
        out = f(*args)
    return out


@skipIfTorchDynamo()
class TestTensorExprFuser(BaseTestClass):
    def test_easy(self):
        """A single elementwise add is fused and matches NumPy."""
        def easy(x, y):
            return torch.add(x, y)

        example_inputs = (torch.rand(1024), torch.rand(1024))
        traced = torch.jit.trace(easy, example_inputs)

        lhs = torch.rand(1024)
        rhs = torch.rand(1024)
        out = warmup_and_run_forward(traced, lhs, rhs)
        self.assertLastGraphAllFused()
        expected = lhs.numpy() + rhs.numpy()
        np.testing.assert_allclose(expected, out.numpy())

    def test_three_arg(self):
        """Two chained adds over three 1-D inputs are fused and match NumPy."""
        def easy(x, y, z):
            xy = torch.add(x, y)
            return torch.add(xy, z)

        traced = torch.jit.trace(
            easy, tuple(torch.rand(1024) for _ in range(3))
        )

        a, b, c = (torch.rand(1024) for _ in range(3))
        out = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        expected = a.numpy() + b.numpy() + c.numpy()
        np.testing.assert_allclose(expected, out.numpy())

    def test_four_arg(self):
        """addcmul over an add result (four inputs) fuses on every device."""
        def run_addcmul(x, y, z, w):
            return torch.addcmul(torch.add(x, y), z, w)

        for dev in self.devices:
            inputs = [torch.rand(1024, dtype=torch.float, device=dev) for _ in range(4)]
            placeholders = tuple(
                torch.zeros(1024, dtype=torch.float, device=dev) for _ in range(4)
            )
            traced = torch.jit.trace(run_addcmul, placeholders)

            out = warmup_and_run_forward(traced, *inputs)
            self.assertLastGraphAllFused()
            ref = run_addcmul(*inputs)
            np.testing.assert_allclose(out.cpu().numpy(), ref.cpu().numpy(), atol=1e-6)

    def test_three_arg2(self):
        """Two chained adds over 2-D (M x N) inputs fuse on every device."""
        def test(x, y, z):
            aaa = torch.add(x, y)
            bbb = torch.add(aaa, z)
            return bbb

        M = 32
        N = 32
        for device in self.devices:
            traced = torch.jit.trace(
                test,
                (
                    torch.rand(M, N, device=device),
                    torch.rand(M, N, device=device),
                    torch.rand(M, N, device=device),
                ),
            )

            a = torch.rand(M, N, device=device)
            b = torch.rand(M, N, device=device)
            c = torch.rand(M, N, device=device)
            # warmup_and_run_forward performs every profiling run itself, so no
            # separate priming call is needed (its result was discarded anyway).
            x = warmup_and_run_forward(traced, a, b, c)
            self.assertLastGraphAllFused()
            npr = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
            np.testing.assert_allclose(npr, x.cpu().numpy())

    def test_broadcast3(self):
        """Adds that broadcast rank-2, rank-3 and rank-4 inputs fuse and match NumPy."""
        def test(x, y, z):
            return torch.add(torch.add(x, y), z)

        # Each config is (M, N, L, K); the shapes are [M, N], [L, M, 1], [K, L, 1, 1].
        shape_configs = [[5, 2, 7, 3], [8, 8, 8, 8]]
        for device, (M, N, L, K) in itertools.product(self.devices, shape_configs):
            shapes = ([M, N], [L, M, 1], [K, L, 1, 1])
            traced = torch.jit.trace(
                test, tuple(torch.rand(*s, device=device) for s in shapes)
            )

            a, b, c = (torch.rand(*s, device=device) for s in shapes)
            out = warmup_and_run_forward(traced, a, b, c)
            self.assertLastGraphAllFused()
            expected = a.cpu().numpy() + b.cpu().numpy() + c.cpu().numpy()
            np.testing.assert_allclose(expected, out.cpu().numpy())

    def test_all_combos(self):
        """A small DAG of adds with reused intermediates fuses and matches NumPy."""
        def easy(x, y, z):
            s1 = torch.add(x, y)
            s2 = torch.add(s1, z)
            s3 = torch.add(x, s2)
            return torch.add(s3, s1)

        def np_easy(x, y, z):
            s1 = x + y
            s3 = x + (s1 + z)
            return s3 + s1

        example = tuple(torch.rand(1024) for _ in range(3))
        traced = torch.jit.trace(easy, example)

        a, b, c = (torch.rand(1024) for _ in range(3))
        out = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        expected = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(expected, out.numpy())

    def test_rank_two(self):
        """The add DAG of test_all_combos, on 32x32 inputs."""
        def easy(x, y, z):
            s1 = torch.add(x, y)
            s2 = torch.add(s1, z)
            s3 = torch.add(x, s2)
            return torch.add(s3, s1)

        def np_easy(x, y, z):
            s1 = x + y
            s3 = x + (s1 + z)
            return s3 + s1

        shape = (32, 32)
        example = tuple(torch.rand(shape) for _ in range(3))
        traced = torch.jit.trace(easy, example)

        a, b, c = (torch.rand(shape) for _ in range(3))
        out = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        expected = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(expected, out.numpy())

    def test_broadcast(self):
        """Adding a 1-D tensor to N x N tensors (broadcast) fuses and matches NumPy."""
        def easy(x, y, z):
            return torch.add(torch.add(x, y), z)

        def np_easy(x, y, z):
            return (x + y) + z

        N = 32
        shapes = ((N, N), (N,), (N, N))
        traced = torch.jit.trace(easy, tuple(torch.rand(*s) for s in shapes))

        a, b, c = (torch.rand(*s) for s in shapes)
        out = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        expected = np_easy(a.numpy(), b.numpy(), c.numpy())
        np.testing.assert_allclose(expected, out.numpy())

    def test_broadcast_2(self):
        """Broadcast of (3, 4) + (3, 1) plus a captured constant and a (4,) tensor."""
        zero = torch.tensor([0.0], dtype=torch.float)

        def foo(x, y, z):
            shifted = torch.add(zero, torch.add(x, y))
            return torch.add(shifted, z)

        def foo_np(x, y, z):
            return (zero.numpy() + (x + y)) + z

        x = torch.rand(3, 4)
        y = torch.ones(3, 1)
        z = torch.rand(4)
        traced = torch.jit.trace(foo, (x, y, z))

        r = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()

        expected = foo_np(x.numpy(), y.numpy(), z.numpy())
        np.testing.assert_allclose(r, expected)

    def test_broadcast_big2(self):
        """The test_broadcast_2 pattern on larger shapes: (32, 1024), (32, 1), (1024,)."""
        zero = torch.tensor([0.0], dtype=torch.float)

        def foo(x, y, z):
            shifted = torch.add(zero, torch.add(x, y))
            return torch.add(shifted, z)

        def foo_np(x, y, z):
            return (zero.numpy() + (x + y)) + z

        x = torch.rand(32, 1024)
        y = torch.ones(32, 1)
        z = torch.rand(1024)
        traced = torch.jit.trace(foo, (x, y, z))

        r = warmup_and_run_forward(traced, x, y, z)
        self.assertLastGraphAllFused()
        expected = foo_np(x.numpy(), y.numpy(), z.numpy())
        np.testing.assert_allclose(r, expected)

    def test_alpha(self):
        """``torch.add`` with a non-default ``alpha`` is fused and computes x + alpha * x."""
        def alpha(x):
            aaa = torch.add(x, x, alpha=2.0)
            return aaa

        traced = torch.jit.trace(alpha, (torch.tensor([1.0])))

        a = torch.tensor([1.0])
        # A single call would only hit a profiling run. Warm up so the value
        # checked below comes from the optimized graph, and assert it was fused
        # like every other test in this class.
        x = warmup_and_run_forward(traced, a)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + 2.0 * a.numpy(), x.numpy())

    @suppress_warnings
    def test_constant(self):
        """Adding a tensor constant created inside the traced function fuses."""
        def constant(x):
            return torch.add(x, torch.tensor([1.0]))

        traced = torch.jit.trace(constant, (torch.tensor([1.0])))

        a = torch.tensor([1.0])
        out = warmup_and_run_forward(traced, a)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + 1.0, out.numpy())

    def test_add_sub(self):
        """(x + y) - z is fused and matches NumPy."""
        def easy(x, y, z):
            return torch.sub(torch.add(x, y), z)

        traced = torch.jit.trace(
            easy, tuple(torch.rand(1024) for _ in range(3))
        )

        a, b, c = (torch.rand(1024) for _ in range(3))
        out = warmup_and_run_forward(traced, a, b, c)
        self.assertLastGraphAllFused()
        expected = a.numpy() + b.numpy() - c.numpy()
        np.testing.assert_allclose(expected, out.numpy())

    def test_promotion(self):
        """int32 + float32 (type promotion) fuses and matches NumPy."""
        def easy(x, y):
            return torch.add(x, y)

        def make_inputs():
            return (
                torch.zeros(1024, dtype=torch.int32),
                torch.rand(1024, dtype=torch.float32),
            )

        traced = torch.jit.trace(easy, make_inputs())

        a, b = make_inputs()
        out = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(a.numpy() + b.numpy(), out.numpy())

    def test_double(self):
        """(x + y) * y on float64 inputs fuses and matches NumPy."""
        TENSOR_LEN = 8

        def easy(x, y):
            return torch.mul(torch.add(x, y), y)

        def make_inputs():
            return (
                torch.rand(TENSOR_LEN, dtype=torch.float64),
                torch.full((TENSOR_LEN,), 0.5, dtype=torch.float64),
            )

        traced = torch.jit.trace(easy, make_inputs())

        a, b = make_inputs()
        out = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        a_np, b_np = a.numpy(), b.numpy()
        np.testing.assert_allclose((a_np + b_np) * b_np, out.numpy())

    def test_short(self):
        """(x + y) * y on int16 inputs fuses and matches NumPy."""
        TENSOR_LEN = 8

        def easy(x, y):
            return torch.mul(torch.add(x, y), y)

        def make_input():
            return torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int16)

        traced = torch.jit.trace(easy, (make_input(), make_input()))

        a = make_input()
        b = make_input()
        out = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        a_np, b_np = a.numpy(), b.numpy()
        np.testing.assert_allclose((a_np + b_np) * b_np, out.numpy())

    def test_char(self):
        """(x + y) * y on int8 inputs fuses and matches NumPy."""
        TENSOR_LEN = 8

        def easy(x, y):
            return torch.mul(torch.add(x, y), y)

        def make_input():
            return torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=torch.int8)

        traced = torch.jit.trace(easy, (make_input(), make_input()))

        a = make_input()
        b = make_input()
        out = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        a_np, b_np = a.numpy(), b.numpy()
        np.testing.assert_allclose((a_np + b_np) * b_np, out.numpy())

    def test_int64_promotion(self):
        """(x + y) * y with int8 x and int64 y (promotion) fuses and matches NumPy."""
        TENSOR_LEN = 8

        def easy(x, y):
            return torch.mul(torch.add(x, y), y)

        def make_input(dtype):
            return torch.randint(TENSOR_LEN, (TENSOR_LEN,), dtype=dtype)

        traced = torch.jit.trace(
            easy, (make_input(torch.int8), make_input(torch.int64))
        )

        a = make_input(torch.int8)
        b = make_input(torch.int64)
        out = warmup_and_run_forward(traced, a, b)
        self.assertLastGraphAllFused()
        a_np, b_np = a.numpy(), b.numpy()
        np.testing.assert_allclose((a_np + b_np) * b_np, out.numpy())

    def test_eq(self):
        """torch.eq fuses; traced with float32 examples, executed on int32 inputs."""
        def easy(x, y):
            return torch.eq(x, y)

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        lhs = torch.zeros(1024, dtype=torch.int32)
        rhs = torch.zeros(1024, dtype=torch.int32)
        out = warmup_and_run_forward(traced, lhs, rhs)
        self.assertLastGraphAllFused()
        # Equal inputs: every element compares True.
        np.testing.assert_allclose(np.ones(1024), out.numpy())

    def test_ne(self):
        """torch.ne fuses; traced with float32 examples, executed on int32 inputs."""
        def easy(x, y):
            return torch.ne(x, y)

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        zeros = torch.zeros(1024, dtype=torch.int32)
        ones = torch.ones(1024, dtype=torch.int32)
        out = warmup_and_run_forward(traced, zeros, ones)
        self.assertLastGraphAllFused()
        # 0 != 1 everywhere.
        np.testing.assert_allclose(np.ones(1024), out.numpy())

    def test_ge(self):
        """torch.ge fuses; traced with float32 examples, executed on int32 inputs."""
        def easy(x, y):
            return torch.ge(x, y)

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        fives = torch.from_numpy(np.full([1024], 5, dtype=np.int32))
        zeros = torch.zeros(1024, dtype=torch.int32)
        out = warmup_and_run_forward(traced, fives, zeros)
        self.assertLastGraphAllFused()
        # 5 >= 0 everywhere.
        np.testing.assert_allclose(np.ones(1024), out.numpy())

    def test_gt(self):
        """torch.gt fuses; traced with float32 examples, executed on int32 inputs."""
        def easy(x, y):
            return torch.gt(x, y)

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        ones = torch.ones(1024, dtype=torch.int32)
        zeros = torch.zeros(1024, dtype=torch.int32)
        out = warmup_and_run_forward(traced, ones, zeros)
        self.assertLastGraphAllFused()
        # 1 > 0 everywhere.
        np.testing.assert_allclose(np.ones(1024), out.numpy())

    def test_le(self):
        """torch.le fuses; traced with float32 examples, executed on int32 inputs."""
        def easy(x, y):
            return torch.le(x, y)

        traced = torch.jit.trace(easy, (torch.zeros(1024), torch.zeros(1024)))
        fives = torch.from_numpy(np.full([1024], 5, dtype=np.int32))
        zeros = torch.zeros(1024, dtype=torch.int32)
        out = warmup_and_run_forward(traced, fives, zeros)
        self.assertLastGraphAllFused()
        # 5 <= 0 is False everywhere.
        np.testing.assert_allclose(np.zeros(1024), out.numpy())

    def test_lt(self):
        """torch.lt fuses on every device; executed on int32 inputs."""
        def easy(x, y):
            return torch.lt(x, y)

        for dev in self.devices:
            example = (torch.zeros(1024, device=dev), torch.zeros(1024, device=dev))
            traced = torch.jit.trace(easy, example)
            ones = torch.ones(1024, dtype=torch.int32, device=dev)
            zeros = torch.zeros(1024, dtype=torch.int32, device=dev)
            out = warmup_and_run_forward(traced, ones, zeros)
            self.assertLastGraphAllFused()
            # 1 < 0 is False everywhere.
            np.testing.assert_allclose(np.zeros(1024), out.cpu().numpy())

    @suppress_warnings
    def test_min_max(self):
        """Elementwise min followed by max against a tensor constant fuses."""
        def test(x, y):
            return torch.max(torch.min(x, y), torch.tensor([4.0]))

        traced = torch.jit.trace(test, (torch.zeros(1024), torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        b = 8.0 * torch.rand(1024)
        actual = warmup_and_run_forward(traced, a, b)
        expected = np.maximum(np.minimum(a.numpy(), b.numpy()), [4.0])
        np.testing.assert_allclose(actual, expected)
        self.assertLastGraphAllFused()

    def test_min_max_reduction(self):
        """Full min + full max reductions (torch.min/torch.max) fuse."""
        def test(x):
            return torch.min(x) + torch.max(x)

        traced = torch.jit.trace(test, (torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        actual = warmup_and_run_forward(traced, a)
        a_np = a.numpy()
        np.testing.assert_allclose(actual, np.amin(a_np) + np.amax(a_np))
        self.assertLastGraphAllFused()

    def test_min_max_reduction2(self):
        """Full min + full max reductions via Tensor methods fuse."""
        def test(x):
            return x.min() + x.max()

        traced = torch.jit.trace(test, (torch.zeros(1024)))
        a = 8.0 * torch.rand(1024)
        actual = warmup_and_run_forward(traced, a)
        a_np = a.numpy()
        np.testing.assert_allclose(actual, np.amin(a_np) + np.amax(a_np))
        self.assertLastGraphAllFused()

    def test_min_max_reduction_dim1(self):
        """min + max reductions along dim 1 (values only) fuse on a 16x16 input."""
        def test(x):
            return torch.min(x, 1)[0] + torch.max(x, 1)[0]

        traced = torch.jit.trace(test, (torch.zeros(16, 16)))
        a = 8.0 * torch.rand(16, 16)
        actual = warmup_and_run_forward(traced, a)
        a_np = a.numpy()
        expected = np.amin(a_np, axis=1) + np.amax(a_np, axis=1)
        np.testing.assert_allclose(actual, expected)
        self.assertLastGraphAllFused()

    def test_min_max_reduction_dim1_2(self):
        """min along dim 1 of x * x fuses; only the values output is checked."""
        def test(x):
            return torch.min(x * x, 1)

        traced = torch.jit.trace(test, (torch.zeros(16, 16)))
        a = 8.0 * torch.rand(16, 16)
        values = warmup_and_run_forward(traced, a)[0]
        np.testing.assert_allclose(values, np.amin((a * a).numpy(), axis=1))
        self.assertLastGraphAllFused()

    def test_clamp(self):
        """clamp(x + 3, 0, 6) fuses on every device and matches np.clip."""
        def test(x):
            return torch.clamp(x + 3.0, 0.0, 6.0)

        for dev in self.devices:
            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
            # Inputs span [-10, 10) so both clamp bounds are hit.
            a = 20.0 * torch.rand(1024, device=dev) - 10.0
            a_np = a.cpu().numpy()
            actual = warmup_and_run_forward(traced, a).cpu()
            np.testing.assert_allclose(actual, np.clip(a_np + 3.0, 0.0, 6.0))
            self.assertLastGraphAllFused()

    def test_relu(self):
        """relu followed by clamp(0, 0.5) fuses on every device."""
        def test(x):
            return torch.clamp(F.relu(x), 0, 0.5)

        for dev in self.devices:
            traced = torch.jit.trace(test, (torch.zeros(1024, device=dev)))
            a = 20.0 * torch.rand(1024, device=dev) - 10.0
            a_np = a.cpu().numpy()
            actual = warmup_and_run_forward(traced, a).cpu()
            expected = np.clip((np.maximum(0, a_np)), 0, 0.5)
            np.testing.assert_allclose(actual, expected)
            self.assertLastGraphAllFused()

    def test_reps(self):
        """Repeatedly warming up and running the same traced add stays correct."""
        def easy(x, y):
            return torch.add(x, y)

        traced = torch.jit.trace(easy, (torch.rand(1024), torch.rand(1024)))

        expected = np.ones(1024)
        for _ in range(32):
            out = warmup_and_run_forward(traced, torch.ones(1024), torch.zeros(1024))
            np.testing.assert_allclose(expected, out.numpy())

    def test_add_const_rhs(self):
        """Adding a Python float constant on the right-hand side fuses."""
        def test(x):
            return x + 3.0

        traced = torch.jit.trace(test, torch.rand(4))
        inp = torch.rand(4)
        out = warmup_and_run_forward(traced, inp)
        self.assertLastGraphAllFused()
        np.testing.assert_allclose(inp.numpy() + 3.0, out.numpy())

    def test_int_output(self):
        """A product of three int32 tensors fuses and yields the NumPy result."""
        def test(x, y, z):
            return x * y * z

        # Integers in [1, 3] after truncation to int32.
        xs = [(torch.rand(4) * 3 + 1).to(torch.int32) for _ in range(3)]
        traced = torch.jit.trace(test, tuple(xs))
        res = warmup_and_run_forward(traced, *xs)
        self.assertLastGraphAllFused()
        xn, yn, zn = (t.numpy() for t in xs)
        np.testing.assert_allclose(xn * yn * zn, res.numpy())

    def test_binary_ops(self):
        """Each binary op, applied to an add result, fuses and matches a reference."""
        def test_atan2(x, y):
            return torch.atan2(torch.add(x, y), y)

        def test_gt(x, y):
            return torch.gt(torch.add(x, y), y)

        def test_ge(x, y):
            return torch.ge(torch.add(x, y), y)

        def test_lt(x, y):
            return torch.lt(torch.add(x, y), y)

        def test_le(x, y):
            return torch.le(torch.add(x, y), y)

        def test_lerp(x, y):
            return torch.lerp(torch.add(x, 1), x, 2.0)

        def test_mul(x, y):
            return torch.mul(torch.add(x, y), y)

        def test_ne(x, y):
            return torch.ne(torch.add(x, y), y)

        def test_div(x, y):
            return torch.div(torch.add(x, y), 2)

        def test_eq(x, y):
            return torch.eq(torch.add(x, y), y)

        def test_fmod(x, y):
            return torch.fmod(torch.add(x, y), 2)

        def test_sub(x, y):
            return torch.sub(torch.add(x, y), x)

        def test_remainder(x, y):
            return torch.remainder(torch.add(x, y), 3.0)

        def test_pow(x, y):
            return torch.pow(torch.add(x, y), 2.0)

        def test_type_as(x, y):
            return x.type_as(torch.add(x, y))

        cmp_fns = {test_gt, test_ge, test_lt, test_le, test_ne, test_eq}
        non_cmp_fns = {
            test_atan2, test_lerp, test_mul, test_div, test_fmod,
            test_sub, test_remainder, test_pow, test_type_as,
        }
        all_test_fns = cmp_fns | non_cmp_fns

        for torch_fn, dev, data_type in itertools.product(all_test_fns, self.devices, self.dtypes):
            is_bf16 = data_type is torch.bfloat16
            # lerp is not exercised for bfloat16.
            if is_bf16 and torch_fn is test_lerp:
                continue
            rand_a = torch.rand(1024, dtype=data_type, device=dev)
            rand_b = torch.rand(1024, dtype=data_type, device=dev)
            in1 = 20 * torch.rand(1024, dtype=data_type, device=dev)
            in2 = 20 * torch.rand(1024, dtype=data_type, device=dev)
            traced = torch.jit.trace(torch_fn, (in1, in2))
            x = warmup_and_run_forward(traced, rand_a, rand_b)
            self.assertLastGraphAllFused()

            if is_bf16:
                # Unlike eager ATen, NNC can fuse a chain of bf16 ops and skip the
                # intermediate BF16/FP32 round trips. For d = a + b - c, ATen
                # computes, operator by operator:
                #    tmp = to_bf16(to_fp32(a) + to_fp32(b))
                #    d = to_bf16(to_fp32(tmp) - to_fp32(c))
                # whereas NNC can emit the single statement
                #    d = to_bf16(to_fp32(a) + to_fp32(b) - to_fp32(c))
                # To keep that numeric deviation out of the comparison, the
                # reference runs the fused graph on fp32 inputs and converts the
                # result back to bf16 (comparison outputs are not converted).
                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
                if torch_fn not in cmp_fns:
                    y = y.bfloat16()
                tolerance = dict(atol=2e-2, rtol=1e-5)
            else:
                y = torch_fn(rand_a, rand_b)
                tolerance = dict(atol=2e-3, rtol=1e-5)
            self.assertEqual(x.cpu(), y.cpu(), **tolerance)

    def test_unary_ops(self):
        """Each unary op, applied to an add result, fuses and matches a reference.

        For bfloat16 inputs of functions outside ``gpu_only_fns``, the reference
        runs the fused graph on fp32 inputs and casts the result to bf16. This
        sidesteps the extra BF16/FP32 round trips that eager mode performs
        between operators. Otherwise the reference is the eager function itself.
        """
        def test_round(x, y):
            return torch.round(torch.add(x, y))

        def test_sin(x, y):
            return torch.sin(torch.add(x, y))

        def test_asin(x, y):
            return torch.asin(torch.add(x, y))

        def test_sinh(x, y):
            return torch.sinh(torch.add(x, y))

        def test_cos(x, y):
            return torch.cos(torch.add(x, y))

        def test_acos(x, y):
            return torch.acos(torch.add(x, y))

        def test_cosh(x, y):
            return torch.cosh(torch.add(x, y))

        def test_tan(x, y):
            return torch.tan(torch.add(x, y))

        def test_atan(x, y):
            return torch.atan(torch.add(x, y))

        def test_tanh(x, y):
            return torch.tanh(torch.add(x, y))

        def test_sqrt(x, y):
            return torch.sqrt(torch.add(x, y))

        def test_rsqrt(x, y):
            return torch.rsqrt(torch.add(x, y))

        def test_floor(x, y):
            return torch.floor(torch.add(x, y))

        def test_ceil(x, y):
            return torch.ceil(torch.add(x, y))

        def test_trunc(x, y):
            return torch.trunc(torch.add(x, y))

        def test_abs(x, y):
            return torch.abs(torch.add(x, y))

        def test_log(x, y):
            return torch.log(torch.add(x, y))

        def test_log2(x, y):
            return torch.log2(torch.add(x, y))

        def test_log10(x, y):
            return torch.log10(torch.add(x, y))

        def test_log1p(x, y):
            return torch.log1p(torch.add(x, y))

        def test_erf(x, y):
            return torch.erf(torch.add(x, y))

        def test_exp(x, y):
            return torch.exp(torch.add(x, y))

        def test_expm1(x, y):
            return torch.expm1(torch.add(x, y))

        def test_erfc(x, y):
            return torch.erfc(torch.add(x, y))

        def test_frac(x, y):
            return torch.frac(torch.add(x, y))

        def test_lgamma(x, y):
            return torch.lgamma(torch.add(x, y))

        def test_sigmoid(x, y):
            return torch.sigmoid(torch.add(x, y))

        def test_reciprocal(x, y):
            return torch.reciprocal(torch.add(x, y))

        def test_neg(x, y):
            return torch.neg(torch.add(x, y))

        def test_relu(x, y):
            return torch.relu(torch.add(x, y))

        def test_hardtanh(x, y):
            return F.hardtanh(torch.add(x, y), -1.0, 1.0)

        def test_threshold(x, y):
            return F.threshold(torch.add(x, y), 0.5, 10)

        # Despite the name, these still run on every device in self.devices.
        # The set only selects the eager reference path for bfloat16 below.
        gpu_only_fns = {
            test_erf,
            test_erfc,
        }
        fns = {
            test_round,
            test_sin,
            test_asin,
            test_sinh,
            test_cos,
            test_acos,
            test_cosh,
            test_tan,
            test_atan,
            test_sqrt,
            test_floor,
            test_ceil,
            test_trunc,
            test_abs,
            test_log,
            test_log2,
            test_log10,
            test_log1p,
            test_rsqrt,
            test_exp,
            test_expm1,
            test_frac,
            test_lgamma,
            test_reciprocal,
            test_neg,
            test_threshold,
            test_relu,
            test_tanh,
            test_hardtanh,
            test_sigmoid,
        }
        fn_dev_dtype = itertools.product(gpu_only_fns.union(fns), self.devices, self.dtypes)

        torch.manual_seed(0)
        for torch_fn, dev, data_type in fn_dev_dtype:
            if torch_fn == test_lgamma and dev == "cuda":
                # lgamma_cuda does not support BF16. Note that lgamma is
                # skipped on CUDA for every dtype, not only bfloat16.
                continue
            rand_a = torch.rand(1024, dtype=data_type, device=dev)
            rand_b = torch.rand(1024, dtype=data_type, device=dev)

            ins = 20 * torch.rand(1024, dtype=data_type, device=dev)
            traced = torch.jit.trace(torch_fn, (ins, ins))
            x = warmup_and_run_forward(traced, rand_a, rand_b)
            self.assertLastGraphAllFused()

            _atol = 5e-3 if data_type is torch.bfloat16 else 2e-3
            _rtol = 1e-5
            if data_type is torch.bfloat16 and torch_fn not in gpu_only_fns:
                y = warmup_and_run_forward(traced, rand_a.float(), rand_b.float())
                y = y.bfloat16()
            else:
                y = torch_fn(rand_a, rand_b)

            self.assertEqual(x.cpu(), y.cpu(), atol=_atol, rtol=_rtol)
            # TODO: also check NaN inputs (an all-NaN first operand) against
            # eager mode. The check was disabled because every function
            # currently fails it.


    def test_round_2(self):
        """torch.round fuses for float32/float64 and matches eager, including .5 ties."""
        # Named round_fn to avoid shadowing the builtin round().
        def round_fn(x):
            return torch.round(x)

        for data_type in [torch.float32, torch.double]:
            a = torch.tensor([0.2, 1.6, 2.5, 3.5]).to(data_type)
            traced = torch.jit.trace(round_fn, (a))
            x = warmup_and_run_forward(traced, a)
            self.assertLastGraphAllFused()
            # Compare against eager rounding of the *input*. Rounding the fused
            # output again is idempotent and could never expose a mismatch.
            y = round_fn(a)
            self.assertEqual(x, y)

    def test_rand_like(self):
        """Fused rand_like fuses for every dtype and produces U[0, 1)-like samples."""
        N = 1 << 16

        def run_rand_like(x, y):
            return torch.rand_like(torch.add(x, y))

        for device in self.devices:
            x = torch.rand(N, device=device)
            traced = torch.jit.trace(run_rand_like, (x, x), check_trace=False)

            for data_type in self.dtypes:
                _x = x.to(dtype=data_type)
                x_v = warmup_and_run_forward(traced, _x, _x)
                self.assertLastGraphAllFused()

                # Check the moments of the *generated* values, not of the input
                # x. For U[0, 1), E[X^k] = 1 / (k + 1). Upcast before NumPy
                # conversion because NumPy has no bfloat16.
                x_np = x_v.float().cpu().numpy()
                x1_mean = np.mean(x_np)
                x2_mean = np.mean(x_np ** 2)
                x3_mean = np.mean(x_np ** 3)
                np.testing.assert_allclose(x1_mean, 1. / 2, rtol=2e-2)
                np.testing.assert_allclose(x2_mean, 1. / 3, rtol=2e-2)
                np.testing.assert_allclose(x3_mean, 1. / 4, rtol=2e-2)

    def test_nans(self):
        """Fused max/min propagate NaN from either operand, for every tested dtype."""
        def test_max(x, y):
            return torch.max(2 * x, 2 * y)

        def test_min(x, y):
            return torch.min(2 * x, 2 * y)

        # The checks must sit inside the loop. Previously they ran after it,
        # so only the last dtype in self.dtypes was ever exercised.
        for data_type in self.dtypes:
            # Trace per dtype so each dtype gets its own profiled, fused graph.
            tmax = torch.jit.trace(
                test_max, (torch.rand(1, dtype=data_type), torch.rand(1, dtype=data_type))
            )
            tmin = torch.jit.trace(
                test_min, (torch.rand(1, dtype=data_type), torch.rand(1, dtype=data_type))
            )

            x = torch.tensor([np.nan]).to(dtype=data_type)
            y = torch.tensor([1.0]).to(dtype=data_type)

            # The NaN must win regardless of argument order.
            self.assertTrue(np.isnan(warmup_and_run_forward(tmin, x, y).float().item()))
            self.assertTrue(np.isnan(warmup_and_run_forward(tmin, y, x).float().item()))
            self.assertLastGraphAllFused()
            self.assertTrue(np.isnan(warmup_and_run_forward(tmax, x, y).float().item()))
            self.assertTrue(np.isnan(warmup_and_run_forward(tmax, y, x).float().item()))
            self.assertLastGraphAllFused()

    def test_double_intrinsics(self):
        """``pow`` with a constant exponent on float64 inputs must fully fuse."""
        def do_pow(inp):
            return torch.pow(inp, 7)

        for dev in self.devices:
            sample = torch.rand(10, dtype=torch.double, device=dev)
            traced_pow = torch.jit.trace(do_pow, sample)
            warmup_and_run_forward(traced_pow, sample)
            self.assertLastGraphAllFused()

    def test_remainder(self):
        """Fused ``remainder(x + y, x)`` must match eager mode.

        Three cases are covered for each dtype: random operands, a zero divisor,
        and NaN in both numerator and divisor.
        """
        def run_remainder(x, y):
            c = torch.remainder(torch.add(x, y), x)
            return c

        for data_type in self.dtypes:
            a = torch.rand(1024, dtype=data_type)
            b = torch.rand(1024, dtype=data_type)
            zeros = torch.zeros(1024, dtype=data_type)
            # A real length-1024 NaN tensor. The previous np.array(1024) built
            # a 0-d array holding the value 1024, so this case silently tested
            # a broadcast scalar instead.
            nans = torch.full((1024,), float("nan"), dtype=data_type)

            # Inputs used only for tracing.
            zeros1 = torch.zeros(1024, dtype=data_type)
            zeros2 = torch.zeros(1024, dtype=data_type)

            # random floats
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, a, b)
            self.assertLastGraphAllFused()
            y = run_remainder(a, b)
            if data_type is torch.bfloat16:
                self.assertEqual(x, y, atol=4e-3, rtol=2e-3)
            else:
                self.assertEqual(x, y)

            # div by 0: remainder(0 + a, 0)
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, zeros, a)
            self.assertLastGraphAllFused()
            y = run_remainder(zeros, a)
            self.assertEqual(x, y)

            # numerators and denominators are NaN: remainder(nan + a, nan)
            traced = torch.jit.trace(run_remainder, (zeros1, zeros2))
            x = warmup_and_run_forward(traced, nans, a)
            self.assertLastGraphAllFused()
            y = run_remainder(nans, a)
            self.assertEqual(x, y)

    def test_multioutput(self):
        """A fused graph may return both an intermediate and a final value."""
        def easy(t):
            first = t + 1
            return (first, first + first)

        traced = torch.jit.trace(easy, torch.zeros(1024))

        inp = torch.zeros(1024)
        out_b, out_c = warmup_and_run_forward(traced, inp)
        self.assertLastGraphAllFused()
        expected_b = inp.numpy() + 1
        np.testing.assert_allclose(out_b.numpy(), expected_b)
        np.testing.assert_allclose(out_c.numpy(), expected_b + expected_b)

    def test_chunk(self):
        """``chunk`` into two halves that are then summed must fuse and match numpy.

        The trace uses a 1024x1024 input, but the graph is run on 32x32.
        """
        def easy(x):
            y = x + 1
            aaa, bbb = torch.chunk(y, 2)
            return aaa + bbb

        for data_type in self.dtypes:
            traced = torch.jit.trace(easy, torch.zeros(1024, 1024, dtype=data_type))

            inp = torch.zeros(32, 32, dtype=data_type)
            out = warmup_and_run_forward(traced, inp)
            self.assertLastGraphAllFused()

            shifted = inp.float().numpy() + 1
            top, bottom = np.array_split(shifted, 2)
            np.testing.assert_allclose(top + bottom, out.float().numpy())

    def test_cat(self):
        """``cat`` feeding an elementwise square must fuse and match eager.

        The first part uses row-major 2-D inputs concatenated on dim 1. The
        second part uses channels-last 4-D inputs concatenated on each dim.
        ``foo`` reads ``_dim`` from the enclosing scope, so rebinding ``_dim``
        changes the dim that is traced next.
        """
        for device in self.devices:
            _dim = 1

            def foo(*args):
                shifted = [t + idx for idx, t in enumerate(args)]
                joined = torch.cat(shifted, dim=_dim)
                return joined * joined

            rows = 16
            widths = [128, 16, 1]
            for data_type in self.dtypes:
                values = [torch.zeros(rows, w, dtype=data_type, device=device) for w in widths]
                traced = torch.jit.trace(foo, values)

                out = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                expected = foo(*values)
                np.testing.assert_allclose(expected.cpu().float().numpy(), out.cpu().float().numpy())

            # Test channels-last
            for _dim in range(4):
                values = [
                    torch.randn((2, 3, 4, 5), device=device).to(memory_format=torch.channels_last)
                    for _ in range(10)
                ]
                traced = torch.jit.trace(foo, values)

                out = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                expected = foo(*values)
                self.assertEqual(expected, out)

    # This test checks that we correctly handle fusion group with just aten::cat in it.
    # Note that the test only makes sense with min_fusion_group=1, otherwise no
    # fusion groups would be formed at all.
    # TODO: Fix and re-enable the test.
    @unittest.skip("cat is broken with fusion group inlining disabled")
    def test_cat_only(self):
        """A fusion group whose only non-trivial op is ``cat`` must match eager."""
        for device in self.devices:
            def foo(*args):
                shifted = [t + idx for idx, t in enumerate(args)]
                return torch.cat(shifted, dim=1)

            values = [torch.zeros(16, w, device=device) for w in (128, 16, 1)]
            traced = torch.jit.trace(foo, values)

            out = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            expected = foo(*values)
            np.testing.assert_allclose(expected.cpu().numpy(), out.cpu().numpy())

    def test_cat_negative_dim(self):
        """``cat`` along ``dim=-1`` followed by a square must fuse and match eager."""
        widths = (128, 16, 1)
        for device in self.devices:
            def foo(*args):
                joined = torch.cat(args, dim=-1)
                return joined * joined

            values = [torch.randn(16, w, device=device) for w in widths]
            traced = torch.jit.trace(foo, values)

            out = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            expected = foo(*values)
            np.testing.assert_allclose(expected.cpu().numpy(), out.cpu().numpy())

    def test_cat_promote_inputs(self):
        """``cat`` over half/float/double inputs must type-promote and still fuse."""
        width_and_dtype = list(zip([128, 16, 1], [torch.half, torch.float32, torch.double]))
        for device in self.devices:
            def foo(*args):
                joined = torch.cat(args, dim=1)
                return joined * joined

            values = [torch.randn(16, w, device=device, dtype=dt) for w, dt in width_and_dtype]
            traced = torch.jit.trace(foo, values)

            out = warmup_and_run_forward(traced, *values)
            self.assertLastGraphAllFused()
            expected = foo(*values)
            np.testing.assert_allclose(expected.cpu().numpy(), out.cpu().numpy())

    def test_cat_empty_tensors(self):
        """``cat`` must fuse and match eager for two input lists.

        One list mixes an empty 1-D double tensor with non-empty 2-D tensors;
        the other contains only empty tensors.
        """
        for device in self.devices:
            def foo(*args):
                joined = torch.cat(args, dim=1)
                return joined * joined

            empty = torch.tensor([], device=device, dtype=torch.double)
            mixed = [empty, *(torch.randn(16, w, device=device) for w in (128, 16, 1))]
            # now test with only empty tensors
            only_empty = [empty] * 3

            for values in (mixed, only_empty):
                traced = torch.jit.trace(foo, values)
                out = warmup_and_run_forward(traced, *values)
                self.assertLastGraphAllFused()
                expected = foo(*values)
                np.testing.assert_allclose(expected.cpu().numpy(), out.cpu().numpy())

    def test_cat_with_constant_dim(self):
        """Nested ``cat`` calls (the outer one over a single tensor) must fuse."""
        for device in self.devices:
            def foo(*args):
                inner = torch.cat(args, dim=1)
                outer = torch.cat([inner], dim=1)
                return outer * outer

            empty = torch.tensor([], device=device, dtype=torch.float32)
            lhs = torch.randn(1, 64, device=device)
            rhs = torch.randn(1, 64, device=device)
            inputs = [empty, lhs, rhs]
            traced = torch.jit.trace(foo, inputs)

            out = warmup_and_run_forward(traced, *inputs)
            self.assertLastGraphAllFused()
            expected = foo(*inputs)
            np.testing.assert_allclose(expected.cpu().numpy(), out.cpu().numpy())

    def test_scalar(self):
        """Scalar ``alpha`` arguments (float- and int-typed) to ``add`` must be honored."""
        @torch.jit.script
        def test_float(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: float, b: float) -> torch.Tensor:
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

        @torch.jit.script
        def test_int(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor, a: int, b: int) -> torch.Tensor:
            return torch.add(torch.add(x, y, alpha=a), z, alpha=b)

        for fn in (test_float, test_int):
            for data_type in self.dtypes:
                x, y, z = (torch.rand(4, dtype=data_type) for _ in range(3))
                alpha, beta = 1, 2
                fn(x, y, z, alpha, beta)  # warm-up call
                result = fn(x, y, z, alpha, beta)
                self.assertEqual(result, x + y * alpha + z * beta)

    def test_loop(self):
        """A scripted loop with fusible (and partly dead) arithmetic gives the right result.

        ``b`` starts at ``y`` and ``y`` is added ``z`` more times, so the result
        is ``y * (z + 1)``. ``a`` is intentionally unused.
        """
        @torch.jit.script
        def test(x: torch.Tensor, y: torch.Tensor, z: int) -> torch.Tensor:
            b = y
            for i in range(0, z):
                a = x + y
                b = b + y
            return b

        x, y, z = (torch.zeros(32, 32), torch.ones(32, 32), 4)
        test(x, y, z)  # warm-up call
        r = test(x, y, z)
        # Previously the result was computed but never checked.
        self.assertEqual(r, y * (z + 1))

    def test_slice(self):
        """Strided slices of both inputs, added together, must match eager slicing."""
        def easy(x, y):
            a = x[0:512:2]
            b = y[0:512:2]
            return a + b

        traced = torch.jit.trace(easy, (torch.ones(1024, 1024), torch.zeros(1024, 1024)))

        inp = torch.ones(1024, 1024)
        out = traced(inp, inp)
        sliced = inp[0:512:2]
        np.testing.assert_allclose((sliced + sliced).numpy(), out.numpy())

    def test_unsqueeze(self, N=256):
        """``unsqueeze(…, 0)`` on both inputs before an add must match numpy.

        Args:
            N: side length of the square inputs.
        """
        def easy(x, y):
            a = torch.unsqueeze(x, 0)
            b = torch.unsqueeze(y, 0)
            return a + b

        traced = torch.jit.trace(easy, (torch.ones(N, N), torch.zeros(N, N)))

        inp = torch.rand(N, N)
        out = traced(inp, inp)
        expanded = np.expand_dims(inp.numpy(), 0)
        np.testing.assert_allclose(expanded + expanded, out.numpy())

    def _test_softmax(self, device):
        """Compare traced softmax/log_softmax graphs against eager mode.

        TensorExpr reductions are enabled for the whole check. The previous
        setting is always restored, even if an assertion fails.

        Args:
            device: device string on which the inputs are created.
        """
        def test_softmax(x, y):
            a = F.softmax(x, dim=0, dtype=torch.float32)
            b = F.softmax(y, dim=0, dtype=torch.float32)
            c = F.softmax(x, dim=1, dtype=torch.float32)
            d = F.softmax(y, dim=1, dtype=torch.float32)
            return a + b + c + d

        def test_softmax_neg_index(x, y):
            a = F.softmax(x, dim=-2, dtype=torch.float32)
            b = F.softmax(y, dim=-2, dtype=torch.float32)
            c = F.softmax(x, dim=-1, dtype=torch.float32)
            d = F.softmax(y, dim=-1, dtype=torch.float32)
            return a + b + c + d

        def test_log_softmax(x, y):
            a = F.log_softmax(x, dim=0, dtype=torch.float32)
            b = F.log_softmax(y, dim=0, dtype=torch.float32)
            c = F.log_softmax(x, dim=1, dtype=torch.float32)
            d = F.log_softmax(y, dim=1, dtype=torch.float32)
            return a + b + c + d

        old = torch._C._jit_set_texpr_reductions_enabled(True)
        try:
            for test in (test_softmax, test_log_softmax, test_softmax_neg_index):
                for data_type in self.dtypes:
                    traced_input = torch.randn(2, 3, dtype=data_type, device=device)
                    traced = torch.jit.trace(test, (traced_input, traced_input))
                    inp = torch.randn(2, 3, dtype=data_type, device=device)
                    res = traced(inp, inp)
                    # Use eager mode as reference. Move it to the host so a
                    # CUDA result can be converted to numpy.
                    ref = test(inp, inp)
                    np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy(), rtol=1e-06, atol=1e-06)
        finally:
            torch._C._jit_set_texpr_reductions_enabled(old)

    def test_softmax_cpu(self):
        """Run the shared softmax checks on the CPU."""
        self._test_softmax(device='cpu')

    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    @unittest.skip("global allocs are not supported yet.")
    def test_softmax_cuda(self):
        """Run the shared softmax checks on CUDA (currently skipped)."""
        self._test_softmax(device='cuda')

    def test_half_gelu(self):
        """A half-precision bias + erf-GELU expression must fully fuse (CUDA only)."""
        devices = ["cuda"] if torch.cuda.is_available() else []

        @torch.jit.script
        def bias_gelu(bias, y):
            x = bias + y
            return x * 0.5 * (1.0 + torch.erf(x / 1.41421))

        for device in devices:
            lhs, rhs = (torch.rand(1024, dtype=torch.half, device=device) for _ in range(2))
            traced = torch.jit.trace(bias_gelu, (lhs, rhs))
            warmup_and_run_forward(traced, lhs, rhs)
            self.assertLastGraphAllFused()

    def test_half_bn_relu(self):
        """Half-precision ``batch_norm`` followed by ``relu`` must fully fuse (CUDA only).

        ``b`` and ``c`` are passed as ``batch_norm``'s running mean and running
        variance.
        """
        devices = ["cuda"] if torch.cuda.is_available() else []

        def foo(a, b, c):
            y = torch.nn.functional.batch_norm(a, b, c)
            z = y.relu()
            return z

        for device in devices:
            a = torch.rand(16, 16, dtype=torch.half, device=device)
            b = torch.rand(16, dtype=torch.half, device=device)
            c = torch.rand(16, dtype=torch.half, device=device)
            traced = torch.jit.trace(foo, (a, b, c))
            # (A leftover debug print of the traced graph was removed here.)
            warmup_and_run_forward(traced, a, b, c)
            self.assertLastGraphAllFused()

    def test_exp_pow(self):
        """A float64 product with ``pow(z, 2)`` must fully fuse on every device."""
        @torch.jit.script
        def do_exp(x, y, z):
            return ((x * y) * 2) * torch.pow(z, 2)

        for device in self.devices:
            x, y, z = (torch.rand(10, dtype=torch.double, device=device) for _ in range(3))
            traced = torch.jit.trace(do_exp, (x, y, z))
            warmup_and_run_forward(traced, x, y, z)
            self.assertLastGraphAllFused()

    def test_sin_pow(self):
        """``sin(pow(x, 0))`` must fuse and match eager for several 1-D shapes."""
        def test(x):
            return torch.sin(torch.pow(x, 0))

        shapes = ([3], [5], [10])
        for data_type, shape in itertools.product(self.dtypes, shapes):
            inp = torch.rand(shape, dtype=data_type)
            scripted = torch.jit.script(test)
            out = warmup_and_run_forward(scripted, inp)
            self.assertLastGraphAllFused()
            self.assertEqual(out, test(inp))

    def test_transpose(self):
        """A transposed input combined with contiguous ones gives stable results.

        The scripted function is called twice; both results must agree.
        """
        @torch.jit.script
        def test(x, y, z):
            return x.transpose(0, 1) + y + z

        x = torch.rand(4, 5, 2, 3)
        y, z = torch.rand(5, 4, 2, 3), torch.rand(5, 4, 2, 3)
        first = test(x, y, z)
        second = test(x, y, z)
        np.testing.assert_allclose(first.numpy(), second.numpy())

    def test_sliced_stride(self):
        """An input sliced with a step of 2 on dim 0 must give stable results.

        The scripted function is called twice; both results must agree.
        """
        @torch.jit.script
        def test(x, y, z):
            return x + y + z

        x = torch.rand(16, 4, 2, 3)[::2]
        y, z = torch.rand(8, 4, 2, 3), torch.rand(8, 4, 2, 3)
        first = test(x, y, z)
        second = test(x, y, z)
        np.testing.assert_allclose(first.numpy(), second.numpy())

    @unittest.skip("dynamic shapes are not quite there yet")
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_dynamic_shape(self):
        """Check a scripted product across shape changes, broadcasts and shape mismatches."""
        with num_profiled_runs(2):
            @torch.jit.script
            def test(x, y, z):
                return x * y * z
            x, y, z = (torch.rand(4, 8).cuda() for _ in range(3))
            ref = test(x, y, z)
            _ = test(*[torch.rand(6, 8).cuda() for _ in range(3)])
            res = test(x, y, z)
            np.testing.assert_allclose(ref.cpu().numpy(), res.cpu().numpy())

            # A wild broadcast appears.
            x = torch.rand(4, 8).cuda()
            y = torch.rand(1, 8).cuda()
            z = torch.rand(4, 1).cuda()
            res = test(x, y, z)
            xn, yn, zn = (t.cpu().numpy() for t in (x, y, z))
            np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)

            # Mismatched shapes shouldn't reach codegen. The error must actually
            # be raised; the old try/except passed silently when it wasn't.
            x = torch.rand(4, 8).cuda()
            y = torch.rand(4, 8).cuda()
            z = torch.rand(5, 8).cuda()
            with self.assertRaisesRegex(RuntimeError, r"The size of tensor a \(4\) must match"):
                test(x, y, z)

            # Changing a static dimension fails guards.
            # x, y, z = [torch.rand(4, 7).cuda() for _ in range(3)]
            # xn, yn, zn = [t.cpu().numpy() for t in (x, y, z)]
            # res = test(x, y, z)
            # print(test.graph_for(x, y, z))
            # np.testing.assert_allclose(res.cpu().numpy(), xn * yn * zn)

    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    def test_guard_fails(self):
        """A scripted product stays correct when the input size changes after warm-up.

        Three calls use size-4 inputs, then one uses size-7 inputs (presumably
        so the optimized graph's shape guard fails). Every result is compared
        with eager mode; previously none were checked.
        """
        @torch.jit.script
        def test(x, y, z):
            return x * y * z

        for size in (4, 4, 4, 7):
            inputs = [torch.rand(size).cuda() for _ in range(3)]
            out = test(*inputs)
            self.assertEqual(out, inputs[0] * inputs[1] * inputs[2])

    def test_bitwise_ops(self):
        """Fused int32 bitwise/shift expressions must fuse and match eager mode."""
        def run_and(x, y):
            return x & (x & y)

        def run_or(x, y):
            return x & (x | y)

        def run_xor(x, y):
            return x ^ (x ^ y)

        def run_lshift(x, y):
            return x & (x << y)

        def run_rshift(x, y):
            return x & (x >> y)

        # A tuple gives a fixed iteration order. A set of functions iterates
        # in hash order, so test order would vary between runs.
        fns = (run_and, run_or, run_xor, run_lshift, run_rshift)

        for device in self.devices:
            for fn in fns:
                a = torch.ones(128, dtype=torch.int32, device=device)
                b = torch.zeros(128, dtype=torch.int32, device=device)
                inp = torch.ones(128, dtype=torch.int32, device=device)
                traced = torch.jit.trace(fn, (inp, inp))
                x = warmup_and_run_forward(traced, a, b)
                self.assertLastGraphAllFused()
                y = fn(a, b)
                np.testing.assert_allclose(x.cpu().numpy(), y.cpu().numpy())

    def test_where(self):
        """``where(gt(x, y), x, y)`` (an elementwise max) must fuse and match eager."""
        def run_where(x, y):
            return torch.where(torch.gt(x, y), x, y)

        for data_type in self.dtypes:
            lhs = torch.rand(1024, dtype=data_type)
            rhs = torch.rand(1024, dtype=data_type)
            trace_inp = torch.zeros(1024, dtype=data_type)
            traced = torch.jit.trace(run_where, (trace_inp, trace_inp))
            got = warmup_and_run_forward(traced, lhs, rhs)
            self.assertLastGraphAllFused()
            want = run_where(lhs, rhs)
            np.testing.assert_allclose(got.float().numpy(), want.float().numpy())

    def test_multi_rand(self):
        """One ``rand_like`` sample used twice must be the same value in both uses.

        (x + r) - (r - x) == 2x holds only if both occurrences of ``r`` agree.
        """
        for device in self.devices:
            def test(x):
                y = torch.rand_like(x)
                return (x + y) - (y - x)

            rtol = 1e-5
            for data_type in self.dtypes:
                # Chosen per iteration, so a looser bfloat16 tolerance cannot
                # leak into later dtypes regardless of their order.
                atol = 2e-2 if data_type is torch.bfloat16 else 2e-3
                a = torch.rand(4, dtype=data_type, device=device)
                scripted = torch.jit.script(test)
                out = warmup_and_run_forward(scripted, a)
                self.assertLastGraphAllFused()
                assert torch.allclose(out, 2 * a, atol=atol, rtol=rtol)

    def test_mask(self):
        """Comparing an unsqueezed boolean mask with 0 must fuse and match eager."""
        def test(x):
            return x.unsqueeze(1) == 0

        for dev, data_type in itertools.product(self.devices, self.dtypes):
            mask = torch.rand(4, dtype=data_type, device=dev) > 0.5
            scripted = torch.jit.script(test)
            out = warmup_and_run_forward(scripted, mask)
            self.assertLastGraphAllFused()
            assert torch.equal(out, test(mask))

    def test_simple_add(self):
        """A traced add is correct with block codegen and fallback both enabled.

        Both global JIT flags are restored in ``finally`` so a failing check
        cannot leak them into later tests.
        """
        val = torch._C._jit_get_te_generate_block_code()
        torch._C._jit_set_te_generate_block_code(True)
        fall_bk = torch._C._jit_texpr_fallback_allowed()
        torch._C._jit_texpr_set_fallback_allowed(True)
        try:
            def simple(a, b):
                return torch.add(a, b)

            a = torch.ones(256, 256)
            b = torch.ones(256, 256)
            traced = torch.jit.trace(simple,
                                     (torch.ones(256, 256), torch.ones(256, 256)))
            f = traced(a, b)
            f_test = np.full((256, 256), 2, dtype=float)
            np.testing.assert_allclose(f.numpy(), f_test)
        finally:
            torch._C._jit_set_te_generate_block_code(val)
            torch._C._jit_texpr_set_fallback_allowed(fall_bk)

    def test_strided_output_preserved(self):
        """Scripted ``a + b - a`` must match eager in both values and output strides.

        Two layouts are checked: a 2x3 view with strides (1, 2) (column-major)
        and a channels-last 4-D tensor.
        """
        def foo(a, b):
            return a + b - a

        def check(inp):
            scripted = torch.jit.script(foo)
            # Two warm-up calls before the call whose output is inspected.
            scripted(inp, inp)
            scripted(inp, inp)
            out_s = scripted(inp, inp)
            out_eager = foo(inp, inp)
            self.assertEqual(out_s, out_eager)
            self.assertEqual(out_s.stride(), out_eager.stride())
            self.assertLastGraphAllFused()

        # smaller, easier to debug example
        x = torch.as_strided(torch.arange(6), (2, 3), (1, 2))
        for value, (i, j) in enumerate(itertools.product(range(2), range(3))):
            x[i, j] = value
        check(x)

        # more dims
        N, C, H, W = 2, 3, 4, 5
        check(torch.rand(N, C, H, W).to(memory_format=torch.channels_last))

    def test_alias_analysis_module(self):
        """Scripted results match eager when module attributes alias each other.

        ``forward`` mutates ``self.b`` in place. After ``a`` is rebound to ``b``,
        that mutation must be visible through ``a`` in the scripted module too.
        """
        class AliasModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                z = z + self.a
                self.b.add_(y)
                w = z + self.a
                z = w + x
                return z
        x = torch.randn(128, 128)

        eager_mod = AliasModule()
        scripted_mod = torch.jit.script(AliasModule())
        torch.testing.assert_close(eager_mod(x, x, x), scripted_mod(x, x, x))

        # Now do the aliasing
        eager_mod.a = eager_mod.b
        expected = eager_mod(x, x, x)

        scripted_mod.a = scripted_mod.b
        actual = scripted_mod(x, x, x)

        torch.testing.assert_close(expected, actual)

    def test_alias_analysis_inputs(self):
        """Scripted results match eager when the same tensor is passed for all inputs.

        ``forward`` mutates ``x`` in place, and ``y``/``z`` alias ``x``.
        """
        class AliasModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                x.add_(y)
                w = z + self.a
                z = w + x
                return z

        eager_mod = AliasModule()
        scripted_mod = torch.jit.script(AliasModule())

        # Reseed before drawing each input so both runs see an identical x.
        results = []
        for mod in (eager_mod, scripted_mod):
            torch.manual_seed(1337)
            x = torch.randn(128, 128)
            results.append(mod(x, x, x))

        expected, actual = results
        torch.testing.assert_close(expected, actual)

    def test_alias_analysis_input_and_module(self):
        """Scripted results match eager when an input also aliases a module attribute.

        ``forward`` mutates ``x`` in place, and ``self.b`` is set to that same ``x``.
        """
        class AliasModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                torch.manual_seed(1337)
                self.a = torch.randn(128, 128)
                self.b = torch.randn(128, 128)
                self.c = torch.randn(128, 128)

            def forward(self, x, y, z):
                x.add_(y)
                w = z + self.b
                z = w + x
                return z

        eager_mod = AliasModule()
        scripted_mod = torch.jit.script(AliasModule())

        # Reseed before drawing each input so both runs see an identical x,
        # which is also installed as the module's `b` attribute.
        results = []
        for mod in (eager_mod, scripted_mod):
            torch.manual_seed(1337)
            x = torch.randn(128, 128)
            mod.b = x
            results.append(mod(x, x, x))

        expected, actual = results
        torch.testing.assert_close(expected, actual)

    def test_multiple_outputs(self):
        """A traced graph with three outputs, one from a strided input, matches eager."""
        for device in self.devices:
            # A bug reported internally similar to the one reported in #48533
            def foo(a, b, c):
                t_next = c + 1
                t5 = t_next * b
                t6 = torch.unsqueeze(t_next, 1)
                t7 = a * t6
                return (t7, t5, t_next)

            for data_type in self.dtypes:
                a = torch.rand(20, 20, dtype=data_type, device=device)
                # A length-20 view with stride 29 into a larger buffer.
                b = torch.rand(20 * 29, dtype=data_type, device=device).as_strided([20], [29])
                c = torch.ones(20, dtype=torch.int64, device=device)
                traced = torch.jit.trace(foo, (a, b, c))
                expected = foo(a, b, c)
                # Call twice; only the second result is compared.
                for _ in range(2):
                    actual = traced(a, b, c)
                self.assertEqual(expected, actual)

    def test_propagated_mem_layout(self):
        """Traced graphs must match eager across mixed memory layouts.

        Covers every combination of function, shape, per-input memory format
        and permutation, under both the "STATIC" and "DYNAMIC" fusion
        strategies.
        """
        def foo(a, b, c):
            t_next = c + 1
            t5 = t_next * b
            t7 = a * t5
            return t7

        def foo_multi_outputs(a, b, c):
            t_next = c + 1
            t5 = b * t_next
            t7 = a * t5
            return (t7, t5, t_next)

        def foo_multi_outputs_i_nhwc_o_nchw(a, b, c):
            t_next = c + 1
            t5 = b * t_next
            t7 = a * t5
            t8 = t7.to(memory_format=torch.contiguous_format)
            return (t8, t7, t5, t_next)

        def run_foo_case(fn, a, b, c):
            traced_contiguous = torch.jit.trace(fn, (a, b, c))
            ref = fn(a, b, c)
            # Call twice; only the second result is compared.
            exp = traced_contiguous(a, b, c)
            exp = traced_contiguous(a, b, c)
            self.assertEqual(ref, exp)

        mem_layouts = list(itertools.product([torch.contiguous_format, torch.channels_last], repeat=3))
        shapes = [(2, 3, 4, 5), (2, 1, 1, 5), (1, 1, 1, 1)]
        permutes = [(0, 3, 2, 1), (0, 3, 1, 2)]
        funcs = [foo, foo_multi_outputs, foo_multi_outputs_i_nhwc_o_nchw]
        # Materialize: itertools.product is a one-shot iterator, so without
        # list() the "DYNAMIC" pass would see zero configurations.
        configs = list(itertools.product(funcs, shapes, mem_layouts, permutes))
        for strategy in ["STATIC", "DYNAMIC"]:
            old_strategy = torch.jit.set_fusion_strategy([(strategy, 10)])
            try:
                for _func, _shape, _mem_layouts, _permute in configs:
                    a = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[0])
                    b = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[1])
                    c = torch.rand(_shape, dtype=torch.float32).to(memory_format=_mem_layouts[2])
                    run_foo_case(_func, a, b, c)

                    a = a.permute(dims=_permute)
                    b = b.permute(dims=_permute)
                    c = c.permute(dims=_permute)
                    run_foo_case(_func, a, b, c)
            finally:
                # Restore the global strategy even if a case fails.
                torch.jit.set_fusion_strategy(old_strategy)

if __name__ == "__main__":
    # Hand control to the run_tests entry point imported from
    # torch.testing._internal.common_utils.
    run_tests()
