# Owner(s): ["module: autograd"]

import contextlib
import warnings

import numpy as np

import torch
from torch.library import _scoped_library, Library
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


@contextlib.contextmanager
def autograd_fallback_mode(mode):
    prev = torch._C._get_autograd_fallback_mode()
    try:
        torch._C._set_autograd_fallback_mode(mode)
        yield
    finally:
        torch._C._set_autograd_fallback_mode(prev)


class TestAutogradFallback(TestCase):
    """Tests for operators that are registered without an autograd kernel.

    Each test defines throwaway operators in the ``test_ns`` namespace, gives
    them a CPU implementation only, and then runs backward through their
    results under both autograd fallback modes ("nothing" and "warn").
    """

    test_ns = "_test_autograd_fallback"

    def tearDown(self):
        """Drop the test namespace and any library created via ``get_lib``."""
        try:
            if hasattr(torch.ops, self.test_ns):
                delattr(torch.ops, self.test_ns)
            if hasattr(self, "lib"):
                del self.lib.m
                del self.lib
        finally:
            # The base class has its own cleanup; make sure it always runs.
            super().tearDown()

    def get_op(self, name):
        """Return the ``default`` overload of operator ``name`` in ``test_ns``."""
        return getattr(getattr(torch.ops, self.test_ns), name).default

    def get_lib(self):
        """Create a FRAGMENT library for ``test_ns`` that ``tearDown`` releases.

        The tests in this class use ``_scoped_library`` instead, which ties the
        library lifetime to a ``with`` block; this helper is kept for callers
        that still want ``tearDown``-managed cleanup.
        """
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        self.lib = lib
        return lib

    def _check_ctx(self, mode, *, mode_nothing_raises=False):
        """Return the context manager describing what backward should do.

        Args:
            mode: the autograd fallback mode under test ("warn" or "nothing").
            mode_nothing_raises: whether, in "nothing" mode, the wrapped code is
                expected to raise a "does not require grad" RuntimeError.

        Returns:
            For "warn", a context asserting that the "an autograd kernel was not
            registered" UserWarning is emitted. For "nothing", either a context
            asserting the RuntimeError above or a no-op context.

        Raises:
            AssertionError: if ``mode`` is neither "warn" nor "nothing".
        """
        if mode == "warn":
            return self.assertWarnsRegex(
                UserWarning, "an autograd kernel was not registered"
            )
        # Explicit raise rather than ``assert`` so the check survives ``-O``.
        if mode != "nothing":
            raise AssertionError(f"unexpected autograd fallback mode: {mode!r}")
        if mode_nothing_raises:
            return self.assertRaisesRegex(RuntimeError, "does not require grad")
        return contextlib.nullcontext()

    @parametrize("mode", ("nothing", "warn"))
    def test_no_grad(self, mode):
        """Calling the op must not warn when its output does not require grad."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
                lib.impl("foo", lambda a, b, c: a + b + c, "CPU")
                op = self.get_op("foo")

                # Inputs require grad, but the call happens under no_grad.
                with warnings.catch_warnings():
                    warnings.simplefilter("error")
                    with torch.no_grad():
                        a = torch.randn([], requires_grad=True)
                        b = torch.randn([], requires_grad=True)
                        out = op(a, b, 1)
                    self.assertFalse(out.requires_grad)

                # No input requires grad.
                with warnings.catch_warnings():
                    warnings.simplefilter("error")
                    a = torch.randn([])
                    b = torch.randn([])
                    out = op(a, b, 1)
                    self.assertFalse(out.requires_grad)

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel(self, mode):
        """Backward through an op whose kernel computes on detached inputs."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor a, Tensor b, int c) -> Tensor")
                op = self.get_op("foo")

                def foo_impl(a, b, c):
                    result = a.detach().numpy() + b.detach().numpy() + c
                    return torch.tensor(result)

                lib.impl("foo", foo_impl, "CPU")

                # Some inputs requiring grad
                a = torch.randn([], requires_grad=False)
                b = torch.randn([], requires_grad=True)
                out = op(a, b, 1).sum()
                with self._check_ctx(mode, mode_nothing_raises=True):
                    out.backward()
                self.assertIsNone(b.grad)

    @parametrize("mode", ("nothing", "warn"))
    def test_no_autograd_kernel_inplace(self, mode):
        """Backward after ops that mutate their inputs in-place under no_grad."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                # input modified in-place gets returned as output
                lib.define(
                    "foo(Tensor(a!) self, Tensor(b!) y) -> (Tensor(a!), Tensor(b!))"
                )
                op = self.get_op("foo")

                def foo_impl(x, y):
                    with torch.no_grad():
                        x.sin_()
                        y.cos_()
                    return x, y

                lib.impl("foo", foo_impl, "CPU")

                x = torch.randn(3, requires_grad=True)
                w = x.clone()
                v = x.clone()
                y0 = w[0]
                y1 = v[1]
                z0, z1 = op(y0, y1)
                for tensor in [w, v, z0, z1, y0, y1]:
                    with self._check_ctx(mode):
                        tensor.sum().backward(retain_graph=True)

                # no outputs: we don't do anything. Maybe we should in the future.
                # This is not a common failure mode.
                lib.define("bar(Tensor(a!) self) -> ()")
                op = self.get_op("bar")

                def bar_impl(x):
                    with torch.no_grad():
                        x.sin_()

                lib.impl("bar", bar_impl, "CPU")
                with warnings.catch_warnings():
                    warnings.simplefilter("error")
                    x = torch.randn([], requires_grad=True)
                    y = x.clone()
                    op(y)
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_cpu_return_self(self, mode):
        """Backward through CPU kernels that hand their input straight back."""
        with autograd_fallback_mode(mode):
            # To be clear, none of these situations are OK and will lead
            # to other problems down the line. We're testing them because
            # it is fairly common to actually do these things.
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

                lib.define("bar(Tensor(a!) self) -> Tensor(a!)")
                lib.impl("bar", lambda x: x, "CPU")
                op = self.get_op("bar")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, torch.ones_like(x))

    @parametrize("mode", ("nothing", "warn"))
    def test_composite_registered_to_cpu(self, mode):
        """Backward through a CPU kernel built from differentiable torch ops."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")
                lib.impl("foo", lambda x: x.sin().sum(), "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x)
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_autograd_function_registered_to_cpu(self, mode):
        """Backward through a ``torch.autograd.Function`` used as the CPU kernel."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor self) -> Tensor")

                class NumpySin(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x)
                        return torch.tensor(np.sin(x.cpu().numpy()))

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                y = op(x).sum()
                with self._check_ctx(mode):
                    y.backward()
                    self.assertEqual(x.grad, x.cos())

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_autograd_function_registered_to_cpu(self, mode):
        """Gradients through an in-place autograd.Function applied to a view."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor(a!) self) -> Tensor(a!)")

                class NumpySin_(torch.autograd.Function):
                    @staticmethod
                    def forward(ctx, x):
                        ctx.save_for_backward(x.clone())
                        x_np = x.detach().numpy()
                        np.sin(x_np, out=x_np)
                        ctx.mark_dirty(x)
                        return x

                    @staticmethod
                    def backward(ctx, gx):
                        (x,) = ctx.saved_tensors
                        return gx * x.cos()

                lib.impl("foo", NumpySin_.apply, "CPU")
                op = self.get_op("foo")

                x = torch.randn(3, requires_grad=True)
                z = x.clone()
                w = z[0]
                y = op(w)

                # Gradient of the op output: only element 0 contributes.
                expected = torch.zeros_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(
                        y, x, torch.ones_like(y), retain_graph=True
                    )
                    self.assertEqual(gx, expected)

                # Gradient of the mutated base: element 0 went through sin.
                expected = torch.ones_like(x)
                expected[0] = x[0].cos()
                with self._check_ctx(mode):
                    (gx,) = torch.autograd.grad(z, x, torch.ones_like(z))
                    self.assertEqual(gx, expected)

    @parametrize("mode", ("nothing", "warn"))
    def test_inplace_on_tensor_that_does_not_require_grad(self, mode):
        """In-place ops on tensors without grad leave them without a grad history."""
        # We don't do anything special (that is, we don't rebase history).
        # See NOTE [autograd fallback and in-place operations] for why
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                # Correct usage of (a!)
                lib.define("foo(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def foo_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x

                lib.impl("foo", foo_impl, "CPU")
                foo = self.get_op("foo")

                # Incorrect usage of (a!): user doesn't return tensor as-is
                lib.define("bar(Tensor(a!) self, Tensor other) -> Tensor(a!)")

                def bar_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)
                    return x_d.clone()

                lib.impl("bar", bar_impl, "CPU")
                bar = self.get_op("bar")

                # User mutated input tensor but didn't return it.
                lib.define("baz(Tensor(a!) self, Tensor other) -> ()")

                def baz_impl(x, y):
                    x_d = x.detach()
                    y = y.detach()
                    x_d.add_(y)

                lib.impl("baz", baz_impl, "CPU")
                baz = self.get_op("baz")

                # Test in-place on non-view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x.clone()
                        op(z, y)
                        torch.autograd.grad(z, y, torch.ones_like(z), allow_unused=True)

                # Test in-place on view
                for op in (foo, bar, baz):
                    x = torch.randn(3)
                    y = torch.randn(3, requires_grad=True)
                    with self.assertRaisesRegex(RuntimeError, "does not require grad"):
                        z = x[:]
                        op(z, y)
                        torch.autograd.grad(z, x, torch.ones_like(z), allow_unused=True)

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_leaf(self, mode):
        """Backward through an output that the kernel returned as a fresh leaf."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor a) -> (Tensor, Tensor)")
                op = self.get_op("foo")

                lib.impl(
                    "foo",
                    lambda a: (a.clone(), a.clone().detach().requires_grad_()),
                    "CPU",
                )
                x = torch.randn(3, requires_grad=True)
                y, z = op(x)
                with self._check_ctx(mode):
                    z.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_inputs_outputs(self, mode):
        """The op is called with a ``None`` input and returns a ``None`` output."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
                op = self.get_op("foo")

                def foo_impl(a, b):
                    return None, b.clone()

                lib.impl("foo", foo_impl, "CPU")

                x = torch.randn(3, requires_grad=True)
                # NB: PyTorch dispatcher treats "None" as undefined Tensor.
                y, z = op(None, x)
                with self._check_ctx(mode):
                    z.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_undefined_grads(self, mode):
        """Backward where both op outputs are routed through ``UndefinedGrad``."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor)")
                op = self.get_op("foo")

                def foo_impl(a, b):
                    return a.sin(), b.cos()

                lib.impl("foo", foo_impl, "CPU")

                x = torch.randn(3, requires_grad=True)
                y = torch.randn(3)
                w, z = op(x, y)
                w = torch._C._functions.UndefinedGrad()(w)
                z = torch._C._functions.UndefinedGrad()(z)
                with self._check_ctx(mode):
                    (z + w).sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_base_does_not_require_grad(self, mode):
        """In-place op on a view whose ``_base`` does not require grad."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor(a!) x) -> Tensor(a!)")
                op = self.get_op("foo")

                def foo_impl(a):
                    with torch.no_grad():
                        return a.zero_()

                lib.impl("foo", foo_impl, "CPU")
                x = torch.randn(3)
                y = x[:]
                y.requires_grad_()
                w = y[:]
                self.assertTrue(w._base is x)

                # Hook should be registered on w, but not w._base
                op(w)
                with self._check_ctx(mode):
                    w.sum().backward()

    @parametrize("mode", ("nothing", "warn"))
    def test_post_autograd_returns_mix_of_requires_grad_tensors(self, mode):
        """Only some of the kernel's outputs were computed with grad enabled."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor a, Tensor b) -> (Tensor, Tensor, Tensor)")
                op = self.get_op("foo")

                def foo_impl(a, b):
                    # x and z are produced under no_grad; y is not.
                    with torch.no_grad():
                        x = a.clone()
                        z = b.clone()
                    y = a * b
                    return x, y, z

                lib.impl("foo", foo_impl, "CPU")
                a = torch.randn(3, requires_grad=True)
                b = torch.randn(3, requires_grad=True)
                x, y, z = op(a, b)

                with self._check_ctx(mode, mode_nothing_raises=True):
                    torch.autograd.grad(
                        x,
                        (a, b),
                        torch.ones_like(x),
                        allow_unused=True,
                        retain_graph=True,
                    )

                with self._check_ctx(mode, mode_nothing_raises=False):
                    torch.autograd.grad(
                        y,
                        (a, b),
                        torch.ones_like(y),
                        allow_unused=True,
                        retain_graph=True,
                    )

                with self._check_ctx(mode, mode_nothing_raises=True):
                    torch.autograd.grad(
                        z,
                        (a, b),
                        torch.ones_like(z),
                        allow_unused=True,
                        retain_graph=True,
                    )

    @parametrize("mode", ("nothing", "warn"))
    def test_supports_tensor_lists(self, mode):
        """The op takes and returns ``Tensor[]``; outputs are computed under no_grad."""
        with autograd_fallback_mode(mode):
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define("foo(Tensor[] a) -> Tensor[]")
                op = self.get_op("foo")

                def foo_impl(a):
                    x, y, z = a
                    with torch.no_grad():
                        return x + y + z, x * y * z

                lib.impl("foo", foo_impl, "CPU")
                x = torch.randn(3, requires_grad=True)
                y = torch.randn(1, requires_grad=True)
                z = torch.randn(2, 1, requires_grad=True)
                a, b = op([x, y, z])
                with self._check_ctx(mode, mode_nothing_raises=True):
                    torch.autograd.grad(
                        a,
                        (x, y, z),
                        torch.ones_like(a),
                        allow_unused=True,
                        retain_graph=True,
                    )
                with self._check_ctx(mode, mode_nothing_raises=True):
                    torch.autograd.grad(
                        b,
                        (x, y, z),
                        torch.ones_like(b),
                        allow_unused=True,
                        retain_graph=True,
                    )


# Generate the concrete test methods for the @parametrize-decorated tests of
# TestAutogradFallback (one per "mode" value) so they can be collected.
instantiate_parametrized_tests(TestAutogradFallback)

# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
