# Owner(s): ["module: decompositions"]

from functools import partial
from itertools import product
import unittest

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (parametrize, run_tests, TestCase, TEST_SCIPY,
                                                  set_default_dtype)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    dtypes,
    OpDTypes,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
)
from torch.testing._internal.common_device_type import (
    ops,
)

from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input
import torch._prims as prims
from torch._prims_common import CUDARngStateHelper
from torch._prims.executor import make_traced
import torch._refs as refs


if TEST_SCIPY:
    import scipy.special

# Message fragments kept as module-level constants.
# NOTE(review): neither constant is referenced by the code shown in this
# module -- confirm whether they are still needed.
NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"

class TestPrims(TestCase):
    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim(self, device, dtype):
        """Exercise ``prims.broadcast_in_dim`` through ``make_traced``.

        ``b`` only supplies the target shape; its values are never read.
        Covered cases: identical shapes, a rejected reordering of dimensions,
        added outermost dimensions, expansion of size-1 dimensions and an
        inserted (unsqueezed) dimension.
        """
        def _wrapper(a, b, broadcast_dimensions):
            return prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            # Same shape
            shape = (5, 5)
            a = make_arg(shape)
            b = make_arg(shape, low=0.0, high=0.0)
            result = fn(a, b, (0, 1))

            self.assertEqual(result.shape, a.shape)
            # ``is_contiguous`` is a method and has to be called; testing the
            # bound method object itself would always be truthy.
            self.assertTrue(result.is_contiguous())
            self.assertEqual(a, result)

            # Error input: reordering dims
            with self.assertRaises(Exception):
                result = fn(a, b, (1, 0))

            # Adding outermost dimensions
            a = make_arg((5, 5))
            b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
            result = fn(a, b, (2, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.broadcast_to(b.shape), result)

            # Expands
            a = make_arg((1, 5, 1))
            b = make_arg((3, 5, 7), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 2))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.expand_as(result), result)

            # Unsqueezes
            a = make_arg((1, 2, 3))
            b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.unsqueeze(2), result)

    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim_sum(self, device, dtype):
        """Trace a full reduction followed by a no-op ``broadcast_in_dim``.

        The traced function must return a 0-dim, contiguous tensor equal to
        what the untraced ``_wrapper`` produces.
        """
        def _wrapper(a):
            a_sum = prims.sum(a, [0, 1])
            a_bc = prims.broadcast_in_dim(a_sum, [], [])
            return a_bc

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            # Call the method: the uncalled bound method is always truthy and
            # would make this assertion vacuous.
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float64, torch.long)
    def test_cbrt_prim(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
        shapes = [(), (0,), (1,), (5,)]

        # Sets the default dtype to NumPy's default dtype of double
        with set_default_dtype(torch.double):
            # Tested here, as this OP is not currently exposed or tested in ATen
            for b, s in product(batches, shapes):
                x = make_arg(b + s)
                y = prims.cbrt(x)

                x_np = x.cpu().numpy()
                y_np = scipy.special.cbrt(x_np)

                self.assertEqual(y, y_np, exact_device=False)

    @dtypes(torch.float32)
    def test_collapse(self, device, dtype):
        t = torch.rand(2, 2, 2)
        dim_ranges = [(0, 0), (0, 1), (1, 2), (0, 2)]
        expected_shapes = [(2, 2, 2), (4, 2), (2, 4), (8,)]

        for (start, end), shape in zip(dim_ranges, expected_shapes):
            expect = t.reshape(shape)

            copy = prims.collapse(t, start, end)
            self.assertEqual(copy, expect)
            self.assertFalse(copy._is_view())

            view = prims.collapse_view(t, start, end)
            self.assertEqual(view, expect)
            self.assertTrue(view._is_view())

        t_discontig = t.transpose(0, 1)
        with self.assertRaises(ValueError, msg="no such view exists"):
            view = prims.collapse_view(t_discontig, 0, 2)

        copy = prims.collapse(t_discontig, 0, 1)
        self.assertEqual(copy, t_discontig.reshape(4, 2))

        error_dims = [(-1, 1), (0, 3), (1, -1)]
        for start, end in error_dims:
            for fn in [prims.collapse, prims.collapse_view]:
                with self.assertRaises(AssertionError):
                    fn(t, start, end)


    def test_aten_overload_to_prims(self, device):
        # This test is to ensure that the torch.ops.aten calls are replaced with refs
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.ops.aten.sigmoid.default(torch.ops.aten.digamma.default(a))

        with TorchRefsMode():
            gm = make_fx(func)(a)

        # Check that all call_function nodes are prims
        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        all_prims_namespace = all(
            node.target.name().startswith("prims") for node in call_function_nodes
        )
        self.assertTrue(all_prims_namespace)

    @onlyCUDA
    @dtypes(torch.float32)
    @parametrize("correction", [0, 1])
    def test_var(self, device, dtype, correction):
        """Trace ``prims.var`` over both dimensions of a 5x5 input.

        The traced result must be a 0-dim, contiguous tensor equal to the
        untraced ``_wrapper`` output for the given ``correction``.
        """
        def _wrapper(a):
            return prims.var(a, [0, 1], correction=correction)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            # Call the method: the uncalled bound method is always truthy and
            # would make this assertion vacuous.
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @dtypes(torch.float32)
    def test_memory_format_strides(self, device, dtype):
        shapes = (
            (),
            (0,),
            (1,),
            (5),
            (1, 0),
            (1, 1),
            (3, 7),
            (3, 0, 2),
            (1, 1, 2),
            (4, 1, 1),
            (7, 8, 9),
        )

        channels_last_shapes = (
            (0, 0, 0, 0),
            (1, 0, 3, 0),
            (0, 2, 3, 5),
            (2, 2, 2, 0),
            (5, 4, 3, 2),
            (8, 8, 7, 2),
            (9, 1, 3, 1),
            (4, 5, 8, 7)
        )

        channels_last_3d_shapes = (
            (0, 8, 7, 9, 2),
            (5, 0, 7, 9, 2),
            (5, 0, 7, 9, 0),
            (5, 8, 7, 9, 2),
            (5, 1, 7, 9, 2),
            (5, 1, 7, 9, 1),
        )

        pairs = (
            (shapes, torch.contiguous_format),
            (channels_last_shapes, torch.contiguous_format),
            (channels_last_3d_shapes, torch.contiguous_format),
            (channels_last_shapes, torch.channels_last),
            (channels_last_3d_shapes, torch.channels_last_3d),
        )

        for shapes, memory_format in pairs:
            for shape in shapes:
                # tests empty
                expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests clone
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype)
                expected = torch.clone(a, memory_format=memory_format)
                actual = torch.clone(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests contiguous
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True)
                expected = a.contiguous(memory_format=memory_format)
                actual = refs.contiguous(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

    @dtypes(torch.float32)
    def test_reshape_view_method(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((5, 5))
        new_shape = 1, 5, 1, 5
        result_eager = a.reshape(*new_shape)
        result_refs = refs.reshape(a, *new_shape)
        self.assertEqual(result_eager, result_refs)

        result_eager = a.view(*new_shape)
        result_refs = refs.view(a, *new_shape)
        self.assertEqual(result_eager, result_refs)


    @onlyCUDA
    @dtypes(torch.float32)
    def test_philox_rand(self, device, dtype):
        """``rngprims.philox_rand`` fed with recorded RNG states must reproduce ``torch.rand``."""
        repeats = 2  # Checks multiple rand calls results with multiple philox_rand calls
        for size in (1000, 1000000):  # offsets of 4 and 8
            torch.cuda.manual_seed(123)
            rng_states = []
            references = []
            for _ in range(repeats):
                # Record the generator state *before* drawing the reference sample.
                state = CUDARngStateHelper.get_torch_state_as_tuple()
                rng_states.append(state)
                references.append(torch.rand(size, device=device, dtype=dtype))

            torch.cuda.manual_seed(123)
            results = []
            for seed, offset in rng_states:
                sample, _unused = torch.ops.rngprims.philox_rand(
                    (size,),
                    seed=seed,
                    offset=offset,
                    stride=None,
                    device=device,
                    dtype=dtype,
                )
                results.append(sample)

            for expected, actual in zip(references, results):
                self.assertEqual(expected, actual)


    @dtypes(torch.float32)
    def test_functional_rng_wrappers(self, device, dtype):

        torch.manual_seed(123)
        ref1 = torch.rand(10, device=device, dtype=dtype)
        ref2 = torch.rand(10, device=device, dtype=dtype)


        torch.manual_seed(123)
        rng_state1, res1 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)
        rng_state2, res2 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)

        res3 = torch._prims.rng_prims.run_with_rng_state(rng_state1, torch.rand, 10, device=device, dtype=dtype)
        res4 = torch._prims.rng_prims.run_with_rng_state(rng_state2, torch.rand, 10, device=device, dtype=dtype)

        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        self.assertEqual(ref1, res3)
        self.assertEqual(ref2, res4)

class TestPrimsBasic(TestCase):
    def test_torch_ops(self):
        r = make_tensor((2,), device='cpu', dtype=torch.float)
        self.assertEqual(torch.ops.prims.sin(r), torch.sin(r))

        r = LoggingTensor(r)
        with capture_logs() as logs:
            log_input("input", r)
            prims.sin(r)
        self.assertExpectedInline('\n'.join(logs), """\
$0: f32[2] = input('input')
$1: f32[2] = torch._ops.prims.sin.default($0)""")

    def test_mul_complex(self):
        prims.mul(torch.randn(2), 1 + 1j)

    def test_clone_complex(self):
        with torch._dispatch.python.enable_python_dispatcher():
            x = torch.randn(4, dtype=torch.complex64, device='meta').conj()
            out = x + 1

    def test_check_deprecation_warning(self):
        with self.assertWarnsRegex(FutureWarning, 'will be removed in the future'):
            torch._prims_common.check(True, lambda: 'message')


# Generate the device-specific variants of TestPrims in this module's namespace.
instantiate_device_type_tests(TestPrims, globals())


class TestRefs(TestCase):
    @dtypes(torch.float32)
    def test_constant_pad_nd_memory_format(self, device, dtype):
        # Test memory format is preserved in unambiguous cases
        for mf, ndim in (
                (torch.channels_last, 4),
                (torch.contiguous_format, 4),
                (torch.channels_last_3d, 5),
                (torch.contiguous_format, 5),
        ):
            a = torch.zeros([2] * ndim).to(memory_format=mf)
            res = refs.constant_pad_nd(a, pad=[1] * (2 * ndim))
            self.assertTrue(res.is_contiguous(memory_format=mf))

        # Ambiguous cases

        # is_channels_last_ and is_contiguous_, results in channels_last output
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 1, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous(memory_format=torch.channels_last))

        # is_channels_last_contiguous_ but not is_channels_last_, results in
        # contiguous output
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 4, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous())

    def test_unbind(self):
        # If unbind returns empty tuple, it breaks some assumptions in some backward tests in test_ops.py.
        # So can't put this test into common_methods_invocations.py.
        a = torch.rand([3, 0, 4])
        actual = refs.unbind(a, 1)
        expect = torch.unbind(a, 1)
        self.assertEqual(actual, expect)

    def test_logspace_with_complex_input(self):
        actual = refs.logspace(2, 10 + 5j, steps=5)
        expect = torch.logspace(2, 10 + 5j, steps=5)
        self.assertEqual(actual, expect)

    def test_linspace_with_complex_input(self):
        actual = refs.linspace(2, 10 + 5j, steps=5)
        expect = torch.linspace(2, 10 + 5j, steps=5)
        self.assertEqual(actual, expect)

    # From https://github.com/pytorch/pytorch/issues/109558
    def test_infinite_loop_from_py_dispatcher(self):
        # enables prim decomps
        with torch._dispatch.python.enable_python_dispatcher():
            x = torch.ones(4)
            y = x.to(device="meta")

    def test_inferred_tags(self):
        self.assertEqual(torch.ops.prims.normal.default.tags, (torch.Tag.nondeterministic_seeded, torch.Tag.pt2_compliant_tag))



# Generate the device-specific variants of TestRefs in this module's namespace.
instantiate_device_type_tests(TestRefs, globals())


class TestDecomp(TestCase):
    @ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
    def test_decomposition_method_vararg(self, device, dtype, op):
        # some ops have vararg variants for the methods. this tests it.
        # we don't have tests for varargs in OpInfo, so we need to
        # improvise this a bit.
        # The rule for general functions (the special cases being e.g. tensor
        # creation functions taking shapes) is that things can be vararg
        # if the method has only one argument of sequence type.
        # e.g. permute can be called on a 3d tensor t as t.permute(0, 2, 1)
        #      as well as t.permute([0, 2, 1])
        #      when the signature in native_functions.yaml
        #      shows arguments Tensor self, IntList dims
        # we might need to adjust things for the factory functions or
        # have them do their own test
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        # filter out empty tuple as that cannot be the varargs
        sample_inputs = (si for si in op.sample_inputs(device, dtype, requires_grad=False)
                         if (si.args[-1] if si.args else si.input))

        # just run one test, we assume there is a suitable one in the tests
        sample_input = next(sample_inputs)
        all_args = (sample_input.input,) + sample_input.args

        # in general, the methods take varargs and not (always?) the function
        # variants, the exception to this rule are the factory functions
        if op.is_factory_function:
            fn = op.op
        else:
            fn = op.method_variant
        with TorchRefsMode():
            gm = make_fx(fn)(*all_args[:-1], *all_args[-1])

        # in case we add random factory functions
        torch.manual_seed(1)
        res = gm(*all_args[:-1], *all_args[-1])
        torch.manual_seed(1)
        expected = fn(*all_args[:-1], *all_args[-1])
        self.assertEqual(res, expected)


# Generate the device-specific variants of TestDecomp in this module's namespace.
instantiate_device_type_tests(TestDecomp, globals())


# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()
