# Owner(s): ["module: typing"]
# based on NumPy numpy/typing/tests/test_typing.py

import itertools
import os
import re
import shutil
import unittest
from collections import defaultdict
from threading import Lock
from typing import Dict, IO, List, Optional

from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


# mypy is an optional dependency: record whether it could be imported so the
# typing tests can be skipped (see the ``skipIf`` on the test class) when it
# is not installed.
try:
    from mypy import api
except ImportError:
    NO_MYPY = True
else:
    NO_MYPY = False


# Absolute path of the ``typing`` directory that sits next to this file.
DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "typing"))
# Sub-directories of DATA_DIR, one per test category.
REVEAL_DIR = os.path.join(DATA_DIR, "reveal")
PASS_DIR = os.path.join(DATA_DIR, "pass")
FAIL_DIR = os.path.join(DATA_DIR, "fail")
# mypy configuration file, looked up two directory levels above DATA_DIR.
MYPY_INI = os.path.join(DATA_DIR, os.pardir, os.pardir, "mypy.ini")
# Cache directory handed to mypy via ``--cache-dir``; lives inside DATA_DIR.
CACHE_DIR = os.path.join(DATA_DIR, ".mypy_cache")


def _key_func(key: str) -> str:
    """Return the part of ``key`` that precedes its first ``:`` character.

    A leading Windows drive specification (*e.g.* ``C:``) is split off first,
    so its colon is not treated as the separator; the drive is then joined
    back onto the result.
    """
    drive_prefix, remainder = os.path.splitdrive(key)
    path_part, _, _ = remainder.partition(":")
    return os.path.join(drive_prefix, path_part)


def _strip_filename(msg: str) -> str:
    """Return ``msg`` without its leading ``<filename>:`` prefix.

    Any drive specification is discarded first. If no ``:`` remains, the
    text after the drive is returned unchanged.
    """
    _drive, remainder = os.path.splitdrive(msg)
    _filename, colon, rest = remainder.partition(":")
    return rest if colon else remainder


def _run_mypy() -> Dict[str, List[str]]:
    """Clear the mypy cache and run mypy over the reveal, pass and fail directories.

    Returns:
        A dictionary mapping each key produced by `_key_func` (the part of an
        output line before its first ``:``; with ``--show-absolute-path``
        this is the absolute file path for regular diagnostics) to the list
        of mypy output lines sharing that key.

    Raises:
        AssertionError: If mypy writes anything to stderr.
    """
    if os.path.isdir(CACHE_DIR):
        shutil.rmtree(CACHE_DIR)

    rc: Dict[str, List[str]] = {}
    for directory in (REVEAL_DIR, PASS_DIR, FAIL_DIR):
        # Run mypy
        stdout, stderr, _ = api.run(
            [
                "--show-absolute-path",
                "--config-file",
                MYPY_INI,
                "--cache-dir",
                CACHE_DIR,
                directory,
            ]
        )
        assert not stderr, stderr
        stdout = stdout.replace("*", "")

        # Parse the output. Lines are accumulated per key rather than grouped
        # with `itertools.groupby`: groupby only groups *consecutive* lines,
        # so output for one file that is interleaved with output for another
        # would otherwise be split into several groups, of which only the
        # last would survive in the dictionary.
        per_run: Dict[str, List[str]] = {}
        for line in stdout.split("\n"):
            key = _key_func(line)
            if key:
                per_run.setdefault(key, []).append(line)
        # As before, a key reported by a later run replaces the earlier entry.
        rc.update(per_run)
    return rc


def get_test_cases(directory):
    """Yield the full path of every enabled ``.py`` file below ``directory``.

    The tree is walked recursively. Files whose name starts with
    ``disabled_`` are skipped, as are files whose extension is not ``.py``.
    """
    for dirpath, _dirnames, filenames in os.walk(directory):
        for filename in filenames:
            _stem, extension = os.path.splitext(filename)
            if extension == ".py" and not filename.startswith("disabled_"):
                yield os.path.join(dirpath, filename)


# Failure text used by `_test_fail` when there is nothing to compare against.
_FAIL_MSG1 = """Extra error at line {}
Extra error: {!r}
"""

# Failure text used by `_test_fail` when the two texts do not match.
_FAIL_MSG2 = """Error mismatch at line {}
Expected error: {!r}
Observed error: {!r}
"""


def _test_fail(
    path: str, error: str, expected_error: Optional[str], lineno: int
) -> None:
    """Raise unless ``error`` occurs as a substring of ``expected_error``.

    Args:
        path: Path of the file under test; not used by the check itself.
        error: Text that must appear inside ``expected_error``.
        expected_error: Text that is searched for ``error``, or ``None`` when
            there is nothing to search.
        lineno: Line number quoted in the failure message.

    Raises:
        AssertionError: With the ``_FAIL_MSG1`` text if ``expected_error`` is
            ``None``, or with the ``_FAIL_MSG2`` text if ``error`` is not
            contained in ``expected_error``.
    """
    # NOTE(review): `TestTyping.test_fail` passes the ``# E:`` marker as
    # `error` and the collected mypy output as `expected_error`, so the
    # "Expected"/"Observed" labels read reversed from the test author's point
    # of view -- confirm whether that is intended.
    if expected_error is None:
        message = _FAIL_MSG1.format(lineno, error)
    elif error in expected_error:
        return
    else:
        message = _FAIL_MSG2.format(lineno, expected_error, error)
    raise AssertionError(message)


def _construct_format_dict():
    """Return the mapping from short format keys to fully qualified type names."""
    return {
        "ModuleList": "torch.nn.modules.container.ModuleList",
        "AdaptiveAvgPool2d": "torch.nn.modules.pooling.AdaptiveAvgPool2d",
        "AdaptiveMaxPool2d": "torch.nn.modules.pooling.AdaptiveMaxPool2d",
        "Tensor": "torch._tensor.Tensor",
        "Adagrad": "torch.optim.adagrad.Adagrad",
        "Adam": "torch.optim.adam.Adam",
    }


#: Maps every supported format key to the text that is substituted for it
#: in ``# E:`` comments.
FORMAT_DICT: Dict[str, str] = _construct_format_dict()


def _parse_reveals(file: IO[str]) -> List[str]:
    """Extract and parse all ``"  # E: "`` comments from the passed file-like object.

    All format keys will be substituted for their respective value from
    `FORMAT_DICT`, *e.g.* ``"{Tensor}"`` becomes ``"torch._tensor.Tensor"``.
    Keys missing from `FORMAT_DICT` are replaced by an
    ``<UNRECOGNIZED FORMAT KEY ...>`` placeholder.

    Args:
        file: Readable text stream holding the source of a reveal test.

    Returns:
        One entry per line of the file (as produced by splitting on
        ``"\\n"``): the formatted text following ``"  # E: "``, or an empty
        string for lines without such a comment. All ``*`` characters are
        removed from the input before parsing.
    """
    string = file.read().replace("*", "")

    # Grab all `# E:`-based comments, keeping one slot per source line so
    # that list indices keep matching (0-based) line numbers.
    comments = [line.partition("  # E: ")[2] for line in string.split("\n")]

    # Only search for the `{*}` pattern within comments,
    # otherwise there is the risk of accidentally grabbing dictionaries and sets
    key_set = {
        key for comment in comments for key in re.findall(r"\{(.*?)\}", comment)
    }
    kwargs = {
        k: FORMAT_DICT.get(k, f"<UNRECOGNIZED FORMAT KEY {k!r}>") for k in key_set
    }

    # Format every comment on its own instead of joining them with a
    # separator and splitting again afterwards: a comment that happens to
    # contain the separator text can then no longer shift the line indices.
    return [comment.format(**kwargs) for comment in comments]


# Failure text used by `_test_reveal` when the two texts do not match.
_REVEAL_MSG = """Reveal mismatch at line {}

Expected reveal: {!r}
Observed reveal: {!r}
"""


def _test_reveal(path: str, reveal: str, expected_reveal: str, lineno: int) -> None:
    """Raise unless ``reveal`` occurs as a substring of ``expected_reveal``.

    Args:
        path: Path of the file under test; not used by the check itself.
        reveal: Text that must appear inside ``expected_reveal``.
        expected_reveal: Text that is searched for ``reveal``.
        lineno: Line number quoted in the failure message.

    Raises:
        AssertionError: With the ``_REVEAL_MSG`` text if ``reveal`` is not
            contained in ``expected_reveal``.
    """
    # NOTE(review): `TestTyping.test_reveal` passes the ``# E:`` marker as
    # `reveal` and the raw mypy output line as `expected_reveal`, so the
    # "Expected"/"Observed" labels read reversed from the test author's point
    # of view -- confirm whether that is intended.
    if reveal in expected_reveal:
        return
    raise AssertionError(_REVEAL_MSG.format(lineno, expected_reveal, reveal))


@unittest.skipIf(NO_MYPY, reason="Mypy is not installed")
class TestTyping(TestCase):
    """Compare mypy's output with the annotated files in the typing data dirs.

    mypy is run once (see `get_mypy_output`) and its output is shared by all
    tests. Each test is parametrized over the ``.py`` files found by
    `get_test_cases` in the matching directory:

    * ``test_success``: files in ``PASS_DIR`` must produce no mypy output.
    * ``test_fail``: in ``FAIL_DIR`` files, every line that mypy reports must
      carry a ``# E:`` marker whose text is contained in mypy's output for
      that line, and every marked line must be reported.
    * ``test_reveal``: in ``REVEAL_DIR`` files, the ``# E:`` comment of each
      reported line must be contained in mypy's ``Revealed type is`` note.
    """

    # Guards the lazy, one-time population of ``_cached_output``.
    _lock = Lock()
    # Result of `_run_mypy`; stays ``None`` until the first test asks for it.
    _cached_output: Optional[Dict[str, List[str]]] = None

    @classmethod
    def get_mypy_output(cls) -> Dict[str, List[str]]:
        """Return the parsed mypy output, running mypy on the first call only."""
        with cls._lock:
            if cls._cached_output is None:
                cls._cached_output = _run_mypy()
            return cls._cached_output

    @parametrize(
        "path",
        get_test_cases(PASS_DIR),
        name_fn=lambda b: os.path.relpath(b, start=PASS_DIR),
    )
    def test_success(self, path) -> None:
        """Fail if mypy reported anything at all for ``path``."""
        output_mypy = self.get_mypy_output()
        if path in output_mypy:
            msg = "Unexpected mypy output\n\n"
            msg += "\n".join(_strip_filename(v) for v in output_mypy[path])
            raise AssertionError(msg)

    @parametrize(
        "path",
        get_test_cases(FAIL_DIR),
        name_fn=lambda b: os.path.relpath(b, start=FAIL_DIR),
    )
    def test_fail(self, path):
        """Check the ``# E:`` markers in ``path`` against mypy's diagnostics.

        Raises:
            ValueError: If a mypy output line does not have the expected
                ``<line>:<col>: (error|note): ...`` shape.
            AssertionError: If mypy reported nothing for ``path``, reported a
                line without a ``# E:`` marker, or a marker does not match.
        """
        __tracebackhide__ = True

        with open(path) as fin:
            lines = fin.readlines()

        # Line number -> concatenated mypy messages reported for that line.
        errors = defaultdict(str)

        output_mypy = self.get_mypy_output()
        self.assertIn(path, output_mypy)
        for error_line in output_mypy[path]:
            error_line = _strip_filename(error_line)
            match = re.match(
                r"(?P<lineno>\d+):(?P<colno>\d+): (error|note): .+$",
                error_line,
            )
            if match is None:
                raise ValueError(f"Unexpected error line format: {error_line}")
            lineno = int(match.group("lineno"))
            errors[lineno] += f"{error_line}\n"

        for lineno, line in enumerate(lines, start=1):
            # Skip comment lines, and lines that are neither marked nor reported.
            if line.startswith("#") or (" E:" not in line and lineno not in errors):
                continue

            # Look the entry up with ``.get()``: indexing the defaultdict
            # (e.g. inside the message below) would insert an empty string
            # for this line, so a marked line without any mypy output would
            # no longer be seen as "missing" by `_test_fail`.
            observed = errors.get(lineno)
            self.assertIn(
                "# E:", line, f"Unexpected mypy output\n\n{observed or ''}"
            )
            marker = line.split("# E:")[-1].strip()
            # `_test_fail` checks that its second argument (the marker) is a
            # substring of its third (mypy's output for this line).
            _test_fail(path, marker, observed, lineno)

    @parametrize(
        "path",
        get_test_cases(REVEAL_DIR),
        name_fn=lambda b: os.path.relpath(b, start=REVEAL_DIR),
    )
    def test_reveal(self, path):
        """Check the ``# E:`` comments in ``path`` against mypy's reveal notes.

        Raises:
            ValueError: If a mypy output line is not a
                ``<file>.py:<line>:<col>: note: ...`` line.
            AssertionError: If mypy reported nothing for ``path``, a note is
                not a ``Revealed type is`` note, or a comment does not match.
        """
        __tracebackhide__ = True

        with open(path) as fin:
            lines = _parse_reveals(fin)

        output_mypy = self.get_mypy_output()
        self.assertIn(path, output_mypy)
        for error_line in output_mypy[path]:
            match = re.match(
                r"^.+\.py:(?P<lineno>\d+):(?P<colno>\d+): note: .+$",
                error_line,
            )
            if match is None:
                raise ValueError(f"Unexpected reveal line format: {error_line}")
            # Convert mypy's 1-based line number into a 0-based list index.
            lineno = int(match.group("lineno")) - 1
            self.assertIn("Revealed type is", error_line)

            marker = lines[lineno]
            _test_reveal(path, marker, error_line, 1 + lineno)


# NOTE(review): presumably expands the ``@parametrize``-decorated methods of
# TestTyping into one concrete test per path -- confirm against
# torch.testing._internal.common_utils.
instantiate_parametrized_tests(TestTyping)

# Run the tests through the torch test harness when executed as a script.
if __name__ == "__main__":
    run_tests()
