# Owner(s): ["oncall: jit"]

import os
import sys
import unittest
import warnings
from typing import Dict, List, Optional

import torch


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase


if __name__ == "__main__":
    # Direct execution is rejected; the message names test/test_jit.py as the
    # entry point to use instead.
    direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(direct_run_message)


class TestScriptModuleInstanceAttributeTypeAnnotation(JitTestCase):
    # NB: There are no tests for `Tuple` or `NamedTuple` here. In fact,
    # reassigning a non-empty Tuple to an attribute previously typed
    # as containing an empty Tuple SHOULD fail. See note in `_check.py`

    def test_annotated_falsy_base_type(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: int = 0

            def forward(self, x: int):
                self.x = x
                return 1

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), (1,))
        assert len(w) == 0

    def test_annotated_nonempty_container(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: List[int] = [1, 2, 3]

            def forward(self, x: List[int]):
                self.x = x
                return 1

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), ([1, 2, 3],))
        assert len(w) == 0

    def test_annotated_empty_tensor(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: torch.Tensor = torch.empty(0)

            def forward(self, x: torch.Tensor):
                self.x = x
                return self.x

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), (torch.rand(2, 3),))
        assert len(w) == 0

    def test_annotated_with_jit_attribute(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.Attribute([], List[int])

            def forward(self, x: List[int]):
                self.x = x
                return self.x

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), ([1, 2, 3],))
        assert len(w) == 0

    def test_annotated_class_level_annotation_only(self):
        class M(torch.nn.Module):
            x: List[int]

            def __init__(self) -> None:
                super().__init__()
                self.x = []

            def forward(self, y: List[int]):
                self.x = y
                return self.x

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), ([1, 2, 3],))
        assert len(w) == 0

    def test_annotated_class_level_annotation_and_init_annotation(self):
        class M(torch.nn.Module):
            x: List[int]

            def __init__(self) -> None:
                super().__init__()
                self.x: List[int] = []

            def forward(self, y: List[int]):
                self.x = y
                return self.x

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), ([1, 2, 3],))
        assert len(w) == 0

    def test_annotated_class_level_jit_annotation(self):
        class M(torch.nn.Module):
            x: List[int]

            def __init__(self) -> None:
                super().__init__()
                self.x: List[int] = torch.jit.annotate(List[int], [])

            def forward(self, y: List[int]):
                self.x = y
                return self.x

        with warnings.catch_warnings(record=True) as w:
            self.checkModule(M(), ([1, 2, 3],))
        assert len(w) == 0

    def test_annotated_empty_list(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: List[int] = []

            def forward(self, x: List[int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
        ):
            with self.assertWarnsRegex(
                UserWarning,
                "doesn't support "
                "instance-level annotations on "
                "empty non-base types",
            ):
                torch.jit.script(M())

    @unittest.skipIf(
        sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
    )
    def test_annotated_empty_list_lowercase(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: list[int] = []

            def forward(self, x: list[int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
        ):
            with self.assertWarnsRegex(
                UserWarning,
                "doesn't support "
                "instance-level annotations on "
                "empty non-base types",
            ):
                torch.jit.script(M())

    def test_annotated_empty_dict(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: Dict[str, int] = {}

            def forward(self, x: Dict[str, int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
        ):
            with self.assertWarnsRegex(
                UserWarning,
                "doesn't support "
                "instance-level annotations on "
                "empty non-base types",
            ):
                torch.jit.script(M())

    @unittest.skipIf(
        sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
    )
    def test_annotated_empty_dict_lowercase(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: dict[str, int] = {}

            def forward(self, x: dict[str, int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
        ):
            with self.assertWarnsRegex(
                UserWarning,
                "doesn't support "
                "instance-level annotations on "
                "empty non-base types",
            ):
                torch.jit.script(M())

    def test_annotated_empty_optional(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x: Optional[str] = None

            def forward(self, x: Optional[str]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Wrong type for attribute assignment", "self.x = x"
        ):
            with self.assertWarnsRegex(
                UserWarning,
                "doesn't support "
                "instance-level annotations on "
                "empty non-base types",
            ):
                torch.jit.script(M())

    def test_annotated_with_jit_empty_list(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.annotate(List[int], [])

            def forward(self, x: List[int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
        ):
            with self.assertWarnsRegex(
                UserWarning,
                "doesn't support "
                "instance-level annotations on "
                "empty non-base types",
            ):
                torch.jit.script(M())

    @unittest.skipIf(
        sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
    )
    def test_annotated_with_jit_empty_list_lowercase(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.annotate(list[int], [])

            def forward(self, x: list[int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
        ):
            with self.assertWarnsRegex(
                UserWarning,
                "doesn't support "
                "instance-level annotations on "
                "empty non-base types",
            ):
                torch.jit.script(M())

    def test_annotated_with_jit_empty_dict(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.annotate(Dict[str, int], {})

            def forward(self, x: Dict[str, int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
        ):
            with self.assertWarnsRegex(
                UserWarning,
                "doesn't support "
                "instance-level annotations on "
                "empty non-base types",
            ):
                torch.jit.script(M())

    @unittest.skipIf(
        sys.version_info[:2] < (3, 9), "Requires lowercase static typing (Python 3.9+)"
    )
    def test_annotated_with_jit_empty_dict_lowercase(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.annotate(dict[str, int], {})

            def forward(self, x: dict[str, int]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Tried to set nonexistent attribute", "self.x = x"
        ):
            with self.assertWarnsRegex(
                UserWarning,
                "doesn't support "
                "instance-level annotations on "
                "empty non-base types",
            ):
                torch.jit.script(M())

    def test_annotated_with_jit_empty_optional(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = torch.jit.annotate(Optional[str], None)

            def forward(self, x: Optional[str]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Wrong type for attribute assignment", "self.x = x"
        ):
            with self.assertWarnsRegex(
                UserWarning,
                "doesn't support "
                "instance-level annotations on "
                "empty non-base types",
            ):
                torch.jit.script(M())

    def test_annotated_with_torch_jit_import(self):
        from torch import jit

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.x = jit.annotate(Optional[str], None)

            def forward(self, x: Optional[str]):
                self.x = x
                return 1

        with self.assertRaisesRegexWithHighlight(
            RuntimeError, "Wrong type for attribute assignment", "self.x = x"
        ):
            with self.assertWarnsRegex(
                UserWarning,
                "doesn't support "
                "instance-level annotations on "
                "empty non-base types",
            ):
                torch.jit.script(M())
