# Owner(s): ["oncall: export"]
import unittest

import torch
from functorch.experimental import control_flow
from torch import Tensor
from torch._dynamo.eval_frame import is_dynamo_supported
from torch._export.verifier import SpecViolationError, Verifier
from torch.export import export
from torch.export.exported_program import InputKind, InputSpec, TensorArgument
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase


@unittest.skipIf(not is_dynamo_supported(), "dynamo isn't supported")
class TestVerifier(TestCase):
    def test_verifier_basic(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        f = Foo()

        ep = export(f, (torch.randn(100), torch.randn(100)))

        verifier = Verifier()
        verifier.check(ep)

    def test_verifier_call_module(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x: Tensor) -> Tensor:
                return self.linear(x)

        gm = torch.fx.symbolic_trace(M())

        verifier = Verifier()
        with self.assertRaises(SpecViolationError):
            verifier._check_graph_module(gm)

    def test_verifier_no_functional(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y

        f = Foo()

        ep = export(f, (torch.randn(100), torch.randn(100)))
        for node in ep.graph.nodes:
            if node.target == torch.ops.aten.add.Tensor:
                node.target = torch.ops.aten.add_.Tensor

        verifier = Verifier()
        with self.assertRaises(SpecViolationError):
            verifier.check(ep)

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_verifier_higher_order(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x + y

                def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x - y

                return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])

        f = Foo()

        ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)))

        verifier = Verifier()
        verifier.check(ep)

    @unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
    def test_verifier_nested_invalid_module(self) -> None:
        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                def true_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x + y

                def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                    return x - y

                return control_flow.cond(x.sum() > 2, true_fn, false_fn, [x, y])

        f = Foo()

        ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)))
        for node in ep.graph_module.true_graph_0.graph.nodes:
            if node.target == torch.ops.aten.add.Tensor:
                node.target = torch.ops.aten.add_.Tensor

        verifier = Verifier()
        with self.assertRaises(SpecViolationError):
            verifier.check(ep)

    def test_ep_verifier_basic(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(10, 10)

            def forward(self, x: Tensor) -> Tensor:
                return self.linear(x)

        ep = export(M(), (torch.randn(10, 10),))
        ep.validate()

    def test_ep_verifier_invalid_param(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_parameter(
                    name="a", param=torch.nn.Parameter(torch.randn(100))
                )

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y + self.a

        ep = export(M(), (torch.randn(100), torch.randn(100)))

        # Parameter doesn't exist in the state dict
        ep.graph_signature.input_specs[0] = InputSpec(
            kind=InputKind.PARAMETER, arg=TensorArgument(name="p_a"), target="bad_param"
        )
        with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
            ep.validate()

        # Add non-torch.nn.Parameter parameter to the state dict
        ep.state_dict["bad_param"] = torch.randn(100)
        with self.assertRaisesRegex(
            SpecViolationError, "not an instance of torch.nn.Parameter"
        ):
            ep.validate()

    def test_ep_verifier_invalid_buffer(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor(3.0)

            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                return x + y + self.a

        ep = export(M(), (torch.randn(100), torch.randn(100)))

        # Buffer doesn't exist in the state dict
        ep.graph_signature.input_specs[0] = InputSpec(
            kind=InputKind.BUFFER,
            arg=TensorArgument(name="c_a"),
            target="bad_buffer",
            persistent=True,
        )
        with self.assertRaisesRegex(SpecViolationError, "not in the state dict"):
            ep.validate()

    def test_ep_verifier_buffer_mutate(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))

                self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
                self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (
                    x1 + self.my_parameter
                ) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0)
                return output

        ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
        ep.validate()

    def test_ep_verifier_invalid_output(self) -> None:
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

                self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))

                self.my_buffer1 = torch.nn.Buffer(torch.tensor(3.0))
                self.my_buffer2 = torch.nn.Buffer(torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (
                    x1 + self.my_parameter
                ) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0)
                return output

        ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))

        output_node = list(ep.graph.nodes)[-1]
        output_node.args = (
            (
                output_node.args[0][0],
                next(iter(ep.graph.nodes)),
                output_node.args[0][1],
            ),
        )

        with self.assertRaisesRegex(SpecViolationError, "Number of output nodes"):
            ep.validate()


if __name__ == "__main__":
    # Run this file's tests through PyTorch's shared test runner.
    run_tests()
