# Owner(s): ["module: decompositions"]

import functools
import itertools
import re
import unittest
from collections import defaultdict
from functools import partial

import torch._inductor.decomposition
import torch.autograd
from torch import Tensor
from torch._decomp import core_aten_decompositions, decomposition_table
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import DispatchKey
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
    skip,
    skipOps,
    xfail,
)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
    is_iterable_of_tensors,
    run_tests,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TEST_WITH_ASAN,
    TEST_WITH_SLOW,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten


aten = torch.ops.aten


# TODO: this isn't going to work with non-aten namespaces
def overload_to_aten_name(op):
    return op._schema.name.split("::")[1]


# All operators that can have decomp tests
decomposition_names = {
    overload_to_aten_name(k)
    for k in decomposition_table
    if isinstance(k, torch._ops.OpOverload)
}
core_decomposition_names = {
    overload_to_aten_name(k)
    for k in core_aten_decompositions()
    if isinstance(k, torch._ops.OpOverload)
}
_decomp_test_ops = [
    op
    for op in op_db
    if op.aten_name in decomposition_names
    or op.aten_backward_name in decomposition_names
]
_decomp_test_ops_core_autograd = [
    op
    for op in op_db
    if op.aten_name in core_decomposition_names and op.supports_autograd
]
_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name]


def diff_arg(arg, requires_grad=True):
    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        else:
            return arg.is_floating_point() or arg.is_complex()

    if is_iterable_of_tensors(arg):
        if all(is_differentiable_arg(a) for a in arg):
            return True
        if all(not is_differentiable_arg(a) for a in arg):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)


# Version of autograd.grad with some differences:
#   - pytree inputs is allowed (but leaves of the pytree have to all
#     be tensors)
#   - if an input is not used as part of derivatives, we will return a
#     zero-filled tensor for the result
def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)


def _as_tuple(val):
    if isinstance(val, tuple):
        return val
    return (val,)


def ref_vjp_no_create(f, *primals):
    result = f(*primals)

    def wrapped(cotangents):
        return _autograd_grad(
            _as_tuple(result),
            primals,
            _as_tuple(cotangents),
            create_graph=False,
            retain_graph=True,
        )

    return result, wrapped


dtype_precisions = {
    torch.float16: (0.001, 1e-5),
    torch.bfloat16: (0.016, 1e-4),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
    torch.complex32: (0.001, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
}
# Returns the "default" rtol and atol for comparing scalars or
# tensors of the given dtypes.


def _getDefaultRtolAndAtol(dtype0, dtype1):
    rtol = max(
        dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0]
    )
    atol = max(
        dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1]
    )
    return rtol, atol


def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
    assert orig.dtype == decomp.dtype, f"{i} Operation:  {op}"
    if orig.numel() == 0 or decomp.numel() == 0:
        assert orig.numel() == decomp.numel()
        return
    assert orig.shape == decomp.shape, f"{i} Operation:  {op}"
    tol_table = {
        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm_backward.default): 1e-3,
        (torch.bfloat16, torch.ops.aten.native_layer_norm_backward.default): 2e-2,
        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.bfloat16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.default): 1e-5,
        (torch.float16, torch.ops.aten._native_batch_norm_legit.no_stats): 1e-5,
        (torch.bfloat16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.float16, torch.ops.aten.linalg_vector_norm.default): 1e-4,
        (torch.bfloat16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.correction): 5e-7,
        (torch.bfloat16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.var_mean.dim): 5e-7,
        (torch.float16, torch.ops.aten.nll_loss_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss_forward.default): 1e-1,
        (torch.float16, torch.ops.aten.nll_loss2d_forward.default): 1e-2,
        (torch.bfloat16, torch.ops.aten.nll_loss2d_forward.default): 2e-1,
        (torch.float16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.bfloat16, torch.ops.aten.hardswish.default): 2e-7,
        (torch.float16, torch.ops.aten.multi_margin_loss.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multi_margin_loss.default): 5e-2,
        (torch.float16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.bfloat16, torch.ops.aten.multilabel_margin_loss_forward.default): 3e-2,
        (torch.float16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad1d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad2d_backward.default): 5e-3,
        (torch.float16, torch.ops.aten.reflection_pad3d_backward.default): 5e-3,
        (torch.bfloat16, torch.ops.aten.reflection_pad3d_backward.default): 5e-2,
        # see https://github.com/pytorch/pytorch/pull/96264
        (torch.float16, torch.ops.aten.mv.default): 1e-5,
        (torch.bfloat16, torch.ops.aten.mv.default): 1e-5,
        (torch.float16, torch.ops.aten.log_sigmoid_backward.default): 2e-5,
        (torch.float16, torch.ops.aten._softmax_backward_data.default): 3e-7,
    }
    if ref.is_floating_point():
        orig_diff = (orig - ref).abs().max()
        decomp_diff = (decomp - ref).abs().max()
        atol = tol_table.get((test_dtype, op), 1e-7)
        if decomp_diff > orig_diff + atol:
            raise RuntimeError(
                f"Difference from float64 is larger with decomposition {op.__name__}"
                f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
                f"atol = {atol}\n"
                f"args = {args}\n"
                f"kwargs = {kwargs}"
            )
    else:
        test_case.assertEqual(
            orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}"
        )


def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
    test_case.assertEqual(
        orig.dtype,
        decomp.dtype,
        f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
    )
    # Before adding an entry to this table, make sure your decomposition is right :)
    tol_table = {
        # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
        (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3),
        (torch.float32, torch.ops.aten.native_layer_norm_backward.default): (
            1e-3,
            1e-3,
        ),
        (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
        # This exceeds default tolerances only on CPU, on CUDA it's fine
        (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5),
        # Exceeds tolerances on CUDA, likely due to fma
        (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5),
        (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4),
        # The decomposition is TOO correct. It computes everything in int64, so sometimes
        # there's an off-by-one error. See
        # https://github.com/pytorch/pytorch/issues/81996
        # https://github.com/pytorch/pytorch/issues/82230
        (torch.int8, torch.ops.aten.linspace.default): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.default): (0, 1),
        (torch.int16, torch.ops.aten.linspace.default): (0, 1),
        (torch.int32, torch.ops.aten.linspace.default): (0, 1),
        (torch.int64, torch.ops.aten.linspace.default): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
    }
    if (decomp.dtype, op) in tol_table:
        rtol, atol = tol_table[(decomp.dtype, op)]
    else:
        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
    test_case.assertEqual(
        orig,
        decomp,
        rtol=rtol,
        atol=atol,
        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
    )


# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point Tensors
# - All outputs of f' are floating-point Tensors
def normalize_op_input_output2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(
        i
        for i, arg in enumerate(flat_args)
        if diff_arg(arg, requires_grad=requires_grad)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            # TODO We should check that the integer outputs also agree
            result = tuple(
                r
                for r in result
                if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex())
            )
            assert len(result) > 0
        return result

    return wrapped, primals


# NB: This also upcasts dtype arguments
# TODO: handle complex correctly
def upcast_tensor(x, dtype=torch.float32):
    if isinstance(x, Tensor) and x.dtype.is_floating_point:
        return x.to(dtype=dtype)
    elif isinstance(x, torch.dtype) and x in [
        torch.float16,
        torch.bfloat16,
        torch.float,
    ]:
        return dtype
    else:
        return x


def normalize_op_input_output(f, sample, requires_grad=True):
    args = tuple([sample.input] + list(sample.args))
    return normalize_op_input_output2(
        f,
        args,
        sample.kwargs,
        sample.output_process_fn_grad,
        requires_grad=requires_grad,
    )


CROSS_REF_EXCLUDE_SET = {
    # CUBLAS_STATUS_NOT_SUPPORTED when calling
    # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
    # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
    # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
    # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
    ("cuda", torch.bfloat16, "nn.functional.bilinear"),
    # randomness
    (None, None, "special.ndtr"),  # aten.special_ndtr was not decomposed
    (None, None, "new_empty"),
    (None, None, "empty_like"),
    (None, None, "empty"),
    # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default.
    (None, None, "item"),
    # It's the only in-place op without an out-of-place equivalent in the Python API
    # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
    (None, None, "zero_"),
    # No idea what's going on here
    # In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), [])
    # in the test, but it seems to pass when tested locally and in the logsumexp test
    (None, torch.float32, "masked.logsumexp"),
    (None, torch.float64, "masked.logsumexp"),
    # exp_vml_cpu not implemented for Half
    (torch.cpu, torch.float16, "signal.windows.exponential"),
    (torch.cpu, torch.float16, "signal.windows.gaussian"),
    # sin_vml_cpu not implemented for Half
    (torch.cpu, torch.float16, "signal.windows.cosine"),
    # CompositeAutogradImplicit
    # See https://github.com/pytorch/pytorch/issues/81669
    (None, None, "nn.functional.relu6"),
    # This decomp runs before autograd.
    (None, None, "nn.functional.rrelu"),
    (None, None, "meshgrid"),
    # Decomposition registered as Autograd
    (None, None, "nn.functional.hardshrink"),
    (None, None, "nn.functional.softshrink"),
    # diag was not decomposed (it just registers a decomp for diag_out, torch.diag is CompImplicit)
    (None, None, "diag"),
    # _softmax_backward_data's CPU kernel for bfloat16 always return the grad_input as float32
    ("cpu", torch.bfloat16, "_softmax_backward_data"),
    (None, None, "norm"),
    # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
    (None, None, "native_batch_norm"),
    (None, None, "_upsample_bilinear2d_aa"),
    (None, None, "empty_strided"),  # aten.empty_strided was not decomposed
}

CROSS_REF_BACKWARD_EXCLUDE_SET = {
    # Decomposed backward formula is not as precise
    ("cpu", torch.bfloat16, "nn.functional.hardswish"),
    ("cuda", torch.float16, "nn.functional.cross_entropy"),
}

all_decomposed = set()
all_called = defaultdict(int)

# Helpful snippet for testing coverage
"""
import atexit
def check_coverage():
    print("missing coverage:")
    print("\n".join(map(str, decomposition_table.keys() - all_decomposed)))
atexit.register(check_coverage)
"""

# Helpful snippet for Horace to create his google sheet :)
"""
import atexit
def dump_ops():
    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
            f.write(f'{op.__name__}\n')
            g.write(f'{count}\n')
    with open('run_decompositions.txt', 'w') as f:
        for op in sorted([i.__name__ for i in all_decomposed]):
            f.write(f'{op}\n')

atexit.register(dump_ops)
"""


def any_unsupported(args, kwargs):
    def test_unsupported(t):
        if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
            # These are all things that we haven't coded decompositions
            # to handle correctly.  Maybe they should.
            return any(
                [
                    t.is_sparse_csr,
                    t.is_sparse,
                    t.is_mkldnn,
                    t.is_quantized,
                    t.is_nested,
                    torch._is_functional_tensor(t),
                ]
            )
        elif torch.overrides.is_tensor_like(t):
            # Decompositions will generally change the behavior of Tensor-like
            # subclasses, so bypass tests in this case too
            return True
        else:
            return False

    flat_args = pytree.arg_tree_leaves(*args, **kwargs)
    return any(test_unsupported(x) for x in flat_args)


core_backward_failures = {
    skip("_softmax_backward_data"),  # slow: fails with --timeout=360 secs
    xfail("addcdiv"),
    skip("addcmul"),  # slow: fails with --timeout=360 secs
    skip("deg2rad"),  # slow: fails with --timeout=360 secs
    skip("diag_embed"),  # slow: fails with --timeout=360 secs
    skip("frac"),  # slow: fails with --timeout=360 secs
    skip("grid_sampler_2d"),  # slow: fails with --timeout=360 secs
    xfail("lerp"),
    skip("logaddexp"),  # slow: fails with --timeout=360 secs
    skip("native_dropout_backward"),  # slow: fails with --timeout=360 secs
    xfail("nn.functional.binary_cross_entropy_with_logits"),
    skip("nn.functional.glu"),  # slow: fails with --timeout=360 secs
    xfail("nn.functional.hardshrink"),
    xfail("nn.functional.softshrink"),
    skip("nn.functional.unfold"),  # slow: fails with --timeout=360 secs
    xfail("norm"),
    xfail("norm", "fro"),
    xfail("norm", "inf"),
    xfail("norm", "nuc"),
    skip("rad2deg"),  # slow: fails with --timeout=360 secs
    skip("renorm"),  # slow: fails with --timeout=360 secs
    skip("rot90"),  # slow: fails with --timeout=360 secs
    skip("rsub"),  # slow: fails with --timeout=360 secs
    skip("sgn"),  # slow: fails with --timeout=360 secs
    skip("special.xlog1py"),  # slow: fails with --timeout=360 secs
    xfail("stack"),
    skip("tril"),  # slow: fails with --timeout=360 secs
    skip("triu"),  # slow: fails with --timeout=360 secs
    skip("unfold_copy"),  # slow: fails with --timeout=360 secs
    skip("xlogy"),  # slow: fails with --timeout=360 secs
    xfail("zero_"),
}
if not TEST_WITH_SLOW:
    core_backward_failures.update(
        {
            skip("addr"),  # slow: takes 46 sec on A100
            skip("baddbmm"),  # slow: takes 800+ sec on A100
            skip("clamp_min"),  # slow: takes 800 sec on A100
            skip("clamp_max"),  # slow: takes 800 sec on A100
            skip("logit"),  # slow: takes 44 sec on A100
            skip("nn.functional.hardswish"),  # slow: takes 60 sec on A100
            skip("std_mean"),  # slow: takes 170 sec on A100
            skip("split", variant_name="list_args"),  # slow: takes 118 sec on A100
            skip("transpose"),  # slow: takes 50 sec on A100
            skip("unbind"),  # slow: takes 70 sec on A100
            skip("unsafe_split"),  # slow: takes 49 sec on A100
        }
    )

comprehensive_failures = {
    xfail(
        "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)
    ),  # off by one error
    xfail(
        "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)
    ),  # off by one error
    xfail(
        "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
    ),  # off by one error
}


@unMarkDynamoStrictTest
class TestDecomp(TestCase):
    longMessage = True

    # NB: This actually overlaps with test_comprehensive, but it only
    # runs on things that are definitely decomposed so it's a lot faster
    # to run
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(_decomp_test_ops)
    def test_quick(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures)
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,))
    def test_quick_core_backward(self, device, dtype, op):
        for sample_input in op.sample_inputs(device, dtype, requires_grad=True):
            aten_name = op.decomp_aten_name or op.aten_name
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            func = partial(op.get_op(), **kwargs)
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=False
            ) as mode, enable_python_dispatcher():
                torch.autograd.gradcheck(func, args)
            self.check_decomposed(aten_name, mode)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures)
    @suppress_warnings
    @ops(op_db)
    def test_comprehensive(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=True)

    def test_uniform(self, device):
        size = (2, 3, 4, 5)
        dtype = torch.float32
        x = make_tensor(size, dtype=dtype, device=device)
        low = 0.3
        high = 0.9

        torch.manual_seed(123)
        ref = torch.ops.aten.uniform(x, low, high)
        torch.manual_seed(123)
        res = torch._decomp.decompositions.uniform(x, low=low, high=high)
        self.assertEqual(ref, res)

    def test_broadcasting_index_copy(self, device):
        x = torch.zeros([1, 10], device=device)
        xs = torch.ones([2, 10], device=device)

        def index_copy(xs, x):
            torch._decomp.decompositions.index_copy_(
                xs, 0, torch.tensor(0).to(device), x
            )

        index_copy(xs, x)

        xs_two = torch.ones([2, 10], device=device)
        xs_two[0] = x

        self.assertEqual(xs, xs_two)

    def test_cat_single_input(self, device):
        decomp_table = torch._inductor.decomposition.select_decomp_table()
        cat_inductor = decomp_table[torch.ops.aten.cat.default]

        inp = torch.rand([2048, 2048], device=device)
        inps = [inp for _ in range(10)]

        for dim in (-1, 0, 1):
            self.assertEqual(torch.cat(inps, dim), cat_inductor(inps, dim))

    def test_rrelu_with_noise(self, device):
        # rrelu_with_noise behavior depends on a) whether elements in the input
        # are <= 0, and b) whether we're in training mode. Cover all cases:
        dtype = torch.float64
        x = torch.tensor([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device)
        lower = 1.0
        upper = 4.0
        training = False

        torch.manual_seed(123)
        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)

        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x,
            noise_res,
            lower,
            upper,
            training,
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)

        # Now with training=True:
        training = True

        torch.manual_seed(123)
        noise_ref = torch.zeros(x.shape, dtype=dtype, device=device)
        ref = torch.ops.aten.rrelu_with_noise(x, noise_ref, lower, upper, training)

        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x,
            noise_res,
            lower,
            upper,
            training,
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @suppress_warnings
    @tf32_off()
    # only tests RNNs since we have py dispsatcher decomps for them
    @modules(
        filter(
            lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
            module_db,
        )
    )
    def test_rnn_decomp_module(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(
            module_info,
            device=device,
            dtype=dtype,
            requires_grad=True,
            training=training,
        )
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            args, kwargs = (
                module_input.constructor_input.args,
                module_input.constructor_input.kwargs,
            )
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)

            args, kwargs = (
                module_input.forward_input.args,
                module_input.forward_input.kwargs,
            )
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=True
            ), enable_python_dispatcher():
                decomp_out = m(*args, **kwargs)

            non_decomp_out = m(*args, **kwargs)
            # without this check, incorrect decomps at the python dispatcher level can still pass because
            # they're checking aten decomps at the torch_dispatch level
            self.assertEqual(decomp_out, non_decomp_out)

    def test_batch_norm_unflatten_weight_bias(self, device):
        # https://github.com/pytorch/pytorch/issues/100970
        shape = (1, 3, 2, 2)
        input = torch.randn(shape, device=device)
        weight = torch.randn((3, 1, 1, 1), device=device)
        bias = torch.randn(3, device=device)
        mean = torch.randn(3, device=device)
        var = torch.randn(3, device=device)
        res = torch._decomp.decompositions.native_batch_norm(
            input, weight, bias, mean, var, False, 1, 1e-05
        )
        self.assertEqual(shape, res[0].shape)

    def test_arange_graph(self, device):
        from torch.fx.experimental.proxy_tensor import make_fx

        def func(x, start):
            le = x.shape[-1]
            if start is None:
                a = torch.arange(le, dtype=torch.float32, device=x.device)
            else:
                a = torch.arange(start, le, dtype=torch.float32, device=x.device)
            return a

        pattern = r", device = device\(.+\), requires_grad = False"

        cfunc = make_fx(func, decomposition_table=decomposition_table)
        fx_g = cfunc(torch.rand(10, device=device), None)
        fx_g_code = fx_g.code.strip()
        # Remove device and requires_grad
        fx_g_code = re.sub(pattern, "", fx_g_code)
        self.assertExpectedInline(
            fx_g_code,
            """\
def forward(self, x_1, start_1):
    iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64)
    mul = torch.ops.prims.mul.default(iota, 1);  iota = None
    add = torch.ops.prims.add.default(mul, 0);  mul = None
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None
    return convert_element_type""",
        )

        fx_g = cfunc(torch.rand(10, device=device), 1)
        fx_g_code = fx_g.code.strip()
        # Remove device and requires_grad
        fx_g_code = re.sub(pattern, "", fx_g_code)
        self.assertExpectedInline(
            fx_g_code,
            """\
def forward(self, x_1, start_1):
    iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64)
    mul = torch.ops.prims.mul.default(iota, 1);  iota = None
    add = torch.ops.prims.add.default(mul, 1);  mul = None
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32);  add = None
    return convert_element_type""",
        )

    def test_masked_fill(self, device):
        from torch.fx.experimental.proxy_tensor import make_fx

        if torch.device(device).type not in [
            "xpu",
            "cuda",
            torch._C._get_privateuse1_backend_name(),
        ]:
            self.skipTest("only runs on XPU and CUDA and PrivateUse1.")

        def func(scores, mask, value):
            return scores.masked_fill(mask, value)

        scores_t = torch.tensor([1, 2, 3, 4], device=device)
        mask_t = torch.tensor([True, True, True, True], device=device)
        value_t = torch.tensor(0, dtype=scores_t.dtype)
        cfunc = make_fx(func, decomposition_table=decomposition_table)
        fx_g = cfunc(scores_t, mask_t, value_t)
        self.assertExpectedInline(
            fx_g.code.strip(),
            """\
def forward(self, scores_1, mask_1, value_1):
    where = torch.ops.prims.where.default(mask_1, value_1, scores_1);  mask_1 = value_1 = scores_1 = None
    return where""",
        )

    class DecompCrossRefMode(TorchDispatchMode):
        def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
            self.test_case = test_case
            self.saved_precision = saved_precision
            self.saved_rel_tol = saved_rel_tol
            self.test_dtype = dtype
            self.run_all = run_all

            # We check the correctness of each decomposition right after running it.
            # So, when we encounter a decomposition, we run the function normally, and
            # then run the decomposition, and ensure they're identical.
            self.called = set()
            self.decomposed = set()

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            self.test_case.precision = self.saved_precision
            self.test_case.rel_tol = self.saved_rel_tol

            self.called.add(func)
            all_called[func] += 1

            # Stuff we shouldn't bother testing
            # (TODO: remove detach from the decomp table?)
            # N.b. Testing in-place ops would need dedicated logic
            in_place = func.name()[-1] == "_"
            ignored_ops = [
                torch.ops.aten.detach.default,
                # non-deterministic ops
                torch.ops.aten.empty.memory_format,
                torch.ops.aten.empty_like.default,
                torch.ops.aten.new_empty.default,
                torch.ops.aten.empty_strided.default,
                torch.ops.aten.new_empty_strided.default,
                torch.ops.aten.randn.default,
                torch.ops.aten.native_dropout.default,
            ]
            if (
                func not in decomposition_table
                or func in ignored_ops
                or torch.Tag.nondeterministic_seeded in func.tags
                or any_unsupported(args, kwargs)
                or in_place
            ):
                return func(*args, **kwargs)

            self.decomposed.add(func)
            all_decomposed.add(func)

            # We take 2 main strategies for verifying correctness/numerical stability of decompositions
            # The first one is simply tolerance checking between decomp_out and pytorch_out
            # However, for fp16/bf16 and reductions, this becomes very
            # finicky, as there are not many guarantees we can make.
            # So, for fp16/bf16, we instead compare the difference of
            # {decomp_out, pytorch_out_64} and {pytorch_out,
            # pytorch_out_64}. In other words, we compare how far the
            # decomposition and pytorch are from the "ground truth" (i.e.
            # fp64). If the decomposition results in more error, we error

            # We also decompose the decomposition recursively for
            # further coverage, as some paths not be exercised directly by
            # OpInfos (sadly) but just by other ops

            decomposition = decomposition_table[func]

            do_relative_check = self.test_dtype in [torch.float16, torch.bfloat16]
            if self.run_all:
                # Execute recursively via DFS, to find the root of a possible error first
                with self:
                    decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))
            else:
                decomp_out = pytree.tree_leaves(decomposition(*args, **kwargs))

            # At this stage we should not be decomposing an in-place op
            # We'd like to have decompositions that decompose out-of-place ops into out-of-place ops
            #  because decompositions are run after functionalisation and we would not like them to
            #  de-functionalise the graph, as that would break AoTAutograd
            # We run the real function *after* the decomposition to make sure that the
            # decomposition does not modify any of the inputs in-place. If it does
            # real_out should be differen than decom_out so we should catch this
            real_out_unflat = func(*args, **kwargs)
            real_out = pytree.tree_leaves(real_out_unflat)

            assert len(real_out) == len(decomp_out)

            if do_relative_check:
                upcast = partial(upcast_tensor, dtype=torch.float64)
                real_out_double, _ = tree_flatten(
                    func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                )
                for i, (orig, decomp, ref) in enumerate(
                    zip(real_out, decomp_out, real_out_double)
                ):
                    if not isinstance(orig, torch.Tensor):
                        assert type(orig) == type(decomp)
                        assert orig == decomp
                        continue
                    op_assert_ref(
                        self.test_case,
                        func,
                        self.test_dtype,
                        i,
                        orig,
                        decomp,
                        ref,
                        args,
                        kwargs,
                    )
            else:
                for orig, decomp in zip(real_out, decomp_out):
                    if not isinstance(orig, torch.Tensor):
                        assert type(orig) == type(decomp)
                        assert orig == decomp
                        continue
                    op_assert_equal(
                        self.test_case,
                        func,
                        self.test_dtype,
                        orig,
                        decomp,
                        args,
                        kwargs,
                    )

            return real_out_unflat

    def check_decomposed(self, aten_name, mode):
        self.assertTrue(
            any(overload_to_aten_name(c) == aten_name for c in mode.decomposed),
            msg=(
                f"aten.{aten_name} was not decomposed, saw calls for: "
                f"{', '.join(map(str, list(mode.called)))}. If your op is  "
                f"CompositeImplicitAutograd you should skip this test "
                f"by updating CROSS_REF_EXCLUDE_SET."
            ),
        )

    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    def do_cross_ref(self, device, dtype, op, *, run_all):
        test_keys = [
            (torch.device(device).type, dtype, op.name),
            (None, dtype, op.name),
            (None, None, op.name),
        ]
        if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
            self.skipTest(f"{op.name} in {dtype} not supported")

        skip_decomp_vjp = any(
            key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys
        )

        requires_grad = (
            op.supports_autograd
            and dtype in op.supported_backward_dtypes(torch.device(device).type)
            # TODO: OpInfo really ought to error out for this case, but it's
            # not exercised in test_ops_gradients atm.  The problem is not
            # complex32 per-se (which is supported by data movement only ops)
            # but that when we do backwards we expect other ops like add to work
            and not dtype == torch.complex32
        )
        samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)

        aten_name = op.decomp_aten_name or op.aten_name

        func = op.get_op()

        def run_without_python_dispatcher(mode):
            return any(
                isinstance(op, torch._ops.OpOverload)
                and op.has_kernel_for_dispatch_key(
                    DispatchKey.CompositeImplicitAutograd
                )
                for op in mode.decomposed.union([func])
            )

        for sample_input in samples:
            if requires_grad:
                fn, primals = normalize_op_input_output(func, sample_input)
                primals = tree_map(
                    lambda x: x if isinstance(x, torch.Tensor) else x, primals
                )

                # Once https://github.com/pytorch/pytorch/pull/75965/ I can
                # store the called list on the mode object instance and no
                # explicit clearing is necessary as I will create a fresh mode
                # for each region
                with self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all
                ) as mode, enable_python_dispatcher():
                    decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if run_without_python_dispatcher(mode):
                    # without this check, incorrect decomps at the python dispatcher level can still pass because
                    # they're checking aten decomps at the torch_dispatch level.
                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode:
                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if aten_name in decomposition_names:
                    self.check_decomposed(aten_name, mode)

                if not skip_decomp_vjp and (
                    op.aten_backward_name in decomposition_names or run_all
                ):
                    cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)

                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode, enable_python_dispatcher():
                        decomp_vjp_fn(cotangents)
                    if run_without_python_dispatcher(mode):
                        # without this check, incorrect decomps at the python dispatcher level can still pass because
                        # they're checking aten decomps at the torch_dispatch level.
                        with self.DecompCrossRefMode(
                            self, self.precision, self.rel_tol, dtype, run_all
                        ) as mode:
                            decomp_vjp_fn(cotangents)
                    if not run_all:
                        self.check_decomposed(op.aten_backward_name, mode)

            elif aten_name in decomposition_names or run_all:
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs
                # A failure here might be because the decomposition for the op is wrong or because a
                # decomposition used by the particular op is wrong.
                with self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all
                ) as mode, enable_python_dispatcher():
                    func(*args, **kwargs)

                if run_without_python_dispatcher(mode):
                    # without this check, incorrect decomps at the python dispatcher level can still pass because
                    # they're checking aten decomps at the torch_dispatch level.
                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode:
                        func(*args, **kwargs)

                if not run_all:
                    self.check_decomposed(aten_name, mode)
            else:
                assert op.supports_autograd
                self.skipTest(
                    "only backwards is decomposed, but dtype doesn't support AD"
                )


instantiate_device_type_tests(TestDecomp, globals())


class DecompOneOffTests(TestCase):
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_contiguous_softmax(self, device):
        size = (2, 4, 3, 3)
        stride = (9, 18, 3, 1)
        dtype = torch.float32

        x = torch.randn(size, dtype=dtype, device=device)
        x = torch.as_strided(x, size, stride)

        ref = torch.ops.aten._softmax(x, -1, False)
        res = torch._decomp.decompositions._softmax(x, -1, False)
        self.assertEqual(ref.stride(), res.stride())

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_contiguous_log_softmax(self, device):
        size = (2, 4, 3, 3)
        stride = (9, 18, 3, 1)

        dtype = torch.float32
        x = torch.randn(size, dtype=dtype, device=device)
        x = torch.as_strided(x, size, stride)

        ref = torch.ops.aten._log_softmax(x, -1, False)
        res = torch._decomp.decompositions._log_softmax(x, -1, False)
        self.assertEqual(ref.stride(), res.stride())

    @onlyCUDA
    def test_exponential_non_inf(self, device):
        inp = torch.empty((4, 400, 256), device=device)

        with torch._dynamo.utils.preserve_rng_state():
            exp_ref = inp.exponential_()
        exp = torch._refs.exponential(inp)

        self.assertEqual(exp, exp_ref)
        self.assertFalse(exp.isinf().any())

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipIfCrossRef
    @onlyCUDA
    def test_amp_batch_norm_backward(self):
        device = "cuda"
        grad_out = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
        x = torch.randn((1, 2, 16, 16), dtype=torch.float16, device=device)
        weight = torch.randn((2,), dtype=torch.float32, device=device)
        rmean = torch.randn((2,), dtype=torch.float32, device=device)
        rvar = torch.randn((2,), dtype=torch.float32, device=device)
        mean = torch.randn((0,), dtype=torch.float32, device=device)

        ref = torch.ops.aten.native_batch_norm_backward(
            grad_out,
            x,
            weight,
            rmean,
            rvar,
            mean,
            mean,
            False,
            1e-05,
            [True, True, True],
        )
        res = torch._decomp.decompositions.native_batch_norm_backward(
            grad_out,
            x,
            weight,
            rmean,
            rvar,
            mean,
            mean,
            False,
            1e-05,
            [True, True, True],
        )
        for a, b in zip(ref, res):
            self.assertEqual(a.stride(), b.stride())
            self.assertEqual(a.dtype, b.dtype)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_elu_backward(self, device):
        size = (2, 4, 3, 3)
        dtype = torch.float32
        grad_out = torch.randn(size, dtype=dtype, device=device)
        out = torch.randn(size, dtype=dtype, device=device)

        ref = torch.ops.aten.elu_backward(grad_out, 1.0, 1, 1, True, out)
        res = torch._decomp.decompositions.elu_backward(grad_out, 1.0, 1, 1, True, out)
        self.assertEqual(ref, res)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_threshold_backward_dtype(self, device):
        grad = torch.randint(10, (4,), device=device)
        input_tensor = torch.randint(10, (4,), device=device)

        ref = torch.ops.aten.threshold_backward(grad, input_tensor, 1)
        res = torch._decomp.decompositions.threshold_backward(grad, input_tensor, 1)
        self.assertEqual(ref.dtype, res.dtype)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    def test_weight_norm_interface(self, device):
        g = torch.randn((3, 10, 10), device=device)
        v = torch.randn((1, 1, 10), device=device)

        ref = torch.ops.aten._weight_norm_interface(g, v, 2)
        res = torch._decomp.decompositions._weight_norm_interface(g, v, 2)
        self.assertTrue(torch.allclose(ref[0], res[0]))
        self.assertTrue(torch.allclose(ref[1], res[1]))

        inp = torch.rand([30, 10], device=device)
        inp2 = torch.rand([30, 1], device=device)

        self.assertEqual(
            torch.ops.aten._weight_norm_interface(inp, inp2),
            torch._decomp.decompositions._weight_norm_interface(inp, inp2),
        )

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyCPU
    @skipIfCrossRef
    @skipOps(
        "DecompOneOffTests",
        "test_sdpa",
        [
            xfail(
                "nn.functional.scaled_dot_product_attention",
                dtypes=[torch.half],
            ),
        ],
    )
    @ops(_sdpa_op_info)
    def test_sdpa(self, device, dtype, op):
        # SDPA doesn't support float16, this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we
        # add support for float16 over there we should update this test as well.

        class ScaledDotProductAttention(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(
                self, query_layer, key_layer, value_layer, mask=None, is_causal=True
            ):
                attn_output = op(
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=mask,
                    dropout_p=0.0,
                    is_causal=is_causal,
                )
                return attn_output

        query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        masks = [None, torch.ones((1, 1, 100, 100), device=device, dtype=torch.bool)]

        atol, rtol = dtype_precisions[dtype]

        for mask in masks:
            is_causal = mask is None
            attention = ScaledDotProductAttention()
            decomposed_res = (
                torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
                    query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
                )
            )
            eager_res = op(
                query_layer,
                key_layer,
                value_layer,
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=is_causal,
            )

            self.assertTrue(
                torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol)
            )


instantiate_device_type_tests(DecompOneOffTests, globals())


class HasDecompTest(TestCase):
    def setUp(self):
        super().setUp()
        self.maxDiff = None

    @staticmethod
    def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool:
        has_tensor_arg = any(
            "Tensor" in str(a.type)
            for a in itertools.chain(op._schema.arguments, op._schema.returns)
        )
        if not has_tensor_arg:
            return False

        try:
            # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions
            return not op.has_kernel_for_dispatch_key(
                DispatchKey.CompositeImplicitAutograd
            )
        except RuntimeError as e:
            # has_key fails for some jit-registered ops, which shouldn't be
            # relevant here anyway
            if "does not exist" in str(e):
                return False
            raise

    def test_has_decomposition(self):
        def all_aten_overloads():
            for name in torch._C._dispatch_get_all_op_names():
                if not name.startswith("aten::"):
                    continue

                name = name[6:]
                if "." in name:
                    packet_name, overload_name = name.split(".")
                else:
                    packet_name, overload_name = name, "default"

                packet = getattr(aten, packet_name)
                assert isinstance(packet, torch._ops.OpOverloadPacket)
                op = getattr(packet, overload_name)
                yield op

        # This is for operators that are only registered in some CI
        # configurations, so would cause the test to fail
        allow_list = {aten.get_gradients.default}

        overloads_wanting_decomp = {
            op for op in all_aten_overloads() if self._can_appear_in_trace(op)
        }
        ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
        ops_missing_decomp -= allow_list
        self.assertExpected(
            "".join(sorted(op.name() + "\n" for op in ops_missing_decomp))
        )

    def test_aten_core_operators(self):
        # If a decomposition isn't included in the core decompositions,
        # then it must decompose a core ATen operator.
        #
        # See NOTE [Core ATen Ops]
        #
        # If this test fails then either:
        # - Add the decomposition to torch._decomp.core_aten_decompositions,
        #   if decomposition should be used by inductor (not a core operator).
        # - Run this test again with EXPECTTEST_ACCEPT=1 to update the list of
        #   core ATen operators (and inductor will not use the decomposition).

        # Some decompositions are registered for CompositeImplicitAutograd
        # operators, which never appear in AOTAutograd's graph so are never used.
        useful_decomps = {
            op
            for op in decomposition_table.keys()
            if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op)
        }
        core_decomps = torch._decomp.core_aten_decompositions().keys()
        core_aten_ops = useful_decomps - core_decomps
        self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))


if __name__ == "__main__":
    run_tests()
