# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing import FileCheck


# torchvision is an optional dependency: record whether the import succeeded
# so that tests which need it can be skipped rather than failing on import.
try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False
# Decorator that skips the decorated test when torchvision could not be imported.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

# Make the helper files in test/ importable: append the parent of the
# directory containing this file (with symlinks resolved) to sys.path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


# Running this module directly as a script is not supported; fail immediately
# with a message that points at the test/test_jit.py entry point instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

# Functional activations exercised by the tests in this file. Each entry is
# called as ``activation(x)`` and as ``activation(x, inplace=True)``, and its
# ``__name__`` is used to build the ``aten::<name>(`` / ``aten::<name>_``
# operator strings that the graph checks look for.
activations = [
    F.celu,
    F.elu,
    F.hardsigmoid,
    F.hardswish,
    F.hardtanh,
    F.leaky_relu,
    F.relu,
    F.relu6,
    F.rrelu,
    F.selu,
    F.silu,
]


class TestFunctionalToInplaceActivation(JitTestCase):
    """Tests for the ``functional_to_inplace_activation`` graph pass."""

    def test_check_no_type_promotion(self):
        """Check that every activation returns a tensor of its input's dtype.

        Combinations for which the activation raises ``RuntimeError`` are
        skipped; only combinations that produce an output are checked.
        """
        dtypes = [
            torch.bool,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.float32,
            torch.float64,
        ]
        # restore_mutation.h contains a mapping from activation operators
        # to whether they allow type conversion. Use this checking to
        # guard the mapping, and if any later change breaks the assumption
        # we need to update the mapping correspondingly.
        for activation, dtype in product(activations, dtypes):
            # subTest makes a failure report name the offending combination.
            with self.subTest(activation=activation.__name__, dtype=dtype):
                inp = torch.normal(0, 5, size=(4, 4)).to(dtype)
                # Keep the try body limited to the call that is expected to
                # raise, so the handler cannot mask anything else.
                try:
                    out = activation(inp)
                except RuntimeError:
                    # Skip the not implemented error
                    continue
                self.assertEqual(dtype, out.dtype)

    def test_functional_to_inplace_activation(self):
        """Check that the pass rewrites each functional activation in place.

        For every activation the scripted graph must contain the functional
        op before the pass and only the in-place op afterwards, and the
        rewritten function must still match the eager result.
        """
        for activation in activations:
            with self.subTest(activation=activation.__name__):

                def test_basic(x):
                    y = x + 1
                    z = activation(y)
                    return z

                functional_op = f"aten::{activation.__name__}("
                inplace_op = f"aten::{activation.__name__}_"

                fn = torch.jit.script(test_basic)
                self.run_pass("inline", fn.graph)
                self.run_pass("constant_propagation", fn.graph)
                FileCheck().check(functional_op).run(fn.graph)
                self.run_pass("functional_to_inplace_activation", fn.graph)
                FileCheck().check_not(functional_op).run(fn.graph)
                FileCheck().check(inplace_op).run(fn.graph)
                inp = torch.rand([2, 2])
                self.assertEqual(fn(inp), test_basic(inp))

    def test_no_functional_to_inplace(self):
        """Check three cases in which the pass must leave the graph functional."""

        # inplace conversion should not happen because sigmoid may
        # perform type conversion
        def test1():
            y = torch.ones([2, 2])
            z = torch.sigmoid(y)
            return z

        fn = torch.jit.script(test1)
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::sigmoid_").run(fn.graph)

        # inplace conversion should not happen because y is alias
        # the input x
        def test2(x):
            y = x[0]
            z = torch.relu(y)
            return z

        fn = torch.jit.script(test2)
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

        # inplace conversion should not happen because self.x is
        # at the global scope
        class Test3(nn.Module):
            def __init__(self, x):
                super().__init__()
                self.x = x

            def forward(self):
                y = torch.relu(self.x)
                return y

        fn = torch.jit.script(Test3(torch.rand([2, 2])).eval())
        self.run_pass("functional_to_inplace_activation", fn.graph)
        FileCheck().check_not("aten::relu_").run(fn.graph)

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        """Check a frozen resnet18 still matches eager output after the pass."""
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass("functional_to_inplace_activation", frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))


class TestInplaceToFunctionalActivation(JitTestCase):
    """Tests for the ``inplace_to_functional_activation`` graph pass."""

    def test_inplace_to_functional_activation(self):
        """Check that the pass rewrites in-place activations to functional ones.

        Covers the functional activations called with ``inplace=True`` as well
        as ``torch.relu_``, ``torch.sigmoid_`` and ``torch.tanh_``. In both
        cases the graph must contain the in-place op before the pass and only
        the functional op afterwards, and the rewritten function must still
        match the eager result.
        """
        for activation in activations:
            # subTest makes a failure report name the offending activation.
            with self.subTest(activation=activation.__name__):

                def test_basic(x):
                    y = x + 1
                    activation(y, inplace=True)
                    return y

                inplace_op = f"aten::{activation.__name__}_"
                functional_op = f"aten::{activation.__name__}("

                fn = torch.jit.script(test_basic)
                self.run_pass("inline", fn.graph)
                self.run_pass("constant_propagation", fn.graph)
                FileCheck().check(inplace_op).run(fn.graph)
                self.run_pass("inplace_to_functional_activation", fn.graph)
                FileCheck().check_not(inplace_op).run(fn.graph)
                FileCheck().check(functional_op).run(fn.graph)

                # Also compare against eager execution, as the loop below
                # does, so the check covers results and not only graph text.
                inp = torch.rand([2, 2])
                self.assertEqual(fn(inp), test_basic(inp))

        for activation in [
            torch.relu_,
            torch.sigmoid_,
            torch.tanh_,
        ]:
            with self.subTest(activation=activation.__name__):

                def test_basic(x):
                    y = x + 1
                    activation(y)
                    return y

                # __name__ already ends in "_"; dropping that last character
                # gives the name of the functional counterpart.
                inplace_op = f"aten::{activation.__name__}"
                functional_op = f"aten::{activation.__name__[:-1]}("

                fn = torch.jit.script(test_basic)
                self.run_pass("inline", fn.graph)
                self.run_pass("constant_propagation", fn.graph)
                FileCheck().check(inplace_op).run(fn.graph)
                self.run_pass("inplace_to_functional_activation", fn.graph)
                FileCheck().check_not(inplace_op).run(fn.graph)
                FileCheck().check(functional_op).run(fn.graph)

                inp = torch.rand([2, 2])
                self.assertEqual(fn(inp), test_basic(inp))

    @skipIfNoTorchVision
    def test_resnet18_correctness(self):
        """Check a frozen resnet18 still matches eager output after the pass."""
        model = torchvision.models.resnet18()
        frozen_model = torch.jit.freeze(torch.jit.script(model.eval()))
        N, C, H, W = 10, 3, 224, 224
        inp = torch.randn(N, C, H, W)
        self.run_pass("inplace_to_functional_activation", frozen_model.graph)
        self.assertEqual(model(inp), frozen_model(inp))
