#!/usr/bin/env python3
"""
This lint verifies that every Python test file (a file matching test_*.py or
*_test.py in the test folder) has a main block which raises an exception or
calls run_tests (or run_rank, for distributed tests) to ensure that the test
will be run in OSS CI.

Takes ~2 minutes to run without the multiprocessing; the multiprocessing is probably overkill.
"""

from __future__ import annotations

import argparse
import json
import multiprocessing as mp
from enum import Enum
from typing import NamedTuple

import libcst as cst
import libcst.matchers as m


# Linter identifier reported in the ``code`` field of every LintMessage.
LINTER_CODE: str = "TEST_HAS_MAIN"


class HasMainVisiter(cst.CSTVisitor):
    """Look for a top-level ``if __name__ == "__main__":`` block that runs tests.

    After visiting a module, ``found`` is True when some top-level ``if``
    whose test compares ``__name__`` with the ``"__main__"`` string literal
    (in either operand order, single- or double-quoted) contains a ``raise``
    or a call to ``run_tests`` / ``run_rank`` (as a bare name or as an
    attribute, e.g. ``common.run_tests()``).
    """

    def __init__(self) -> None:
        super().__init__()
        # Set to True once a qualifying main block is seen; never reset.
        self.found = False

    def visit_Module(self, node: cst.Module) -> bool:
        dunder_name = m.Name("__name__")
        main_literal = m.SimpleString('"__main__"') | m.SimpleString("'__main__'")

        def _equals(lhs, rhs):
            # Matches the single comparison `lhs == rhs`.
            return m.Comparison(lhs, [m.ComparisonTarget(m.Equal(), rhs)])

        def _call_to(func_name):
            # Matches `func_name(...)` as well as `<anything>.func_name(...)`.
            return m.Call(
                func=m.Name(func_name) | m.Attribute(attr=m.Name(func_name))
            )

        # Accept both `__name__ == "__main__"` and `"__main__" == __name__`.
        main_guard = m.If(
            test=_equals(dunder_name, main_literal)
            | _equals(main_literal, dunder_name)
        )

        # Distributed tests (i.e. MultiProcContinuousTest) call `run_rank`
        # instead of `run_tests` in their main block.
        runs_or_raises = m.Raise() | _call_to("run_tests") | _call_to("run_rank")

        guarded_blocks = (
            stmt for stmt in node.children if m.matches(stmt, main_guard)
        )
        if any(m.findall(block, runs_or_raises) for block in guarded_blocks):
            self.found = True

        # Only top-level statements are inspected; do not descend further.
        return False


class LintSeverity(str, Enum):
    """Severity level attached to a LintMessage.

    Because ``str`` is mixed in, every member is itself a string equal to its
    lowercase value. This lets it be serialised directly with ``json.dumps``.
    """

    ERROR = "error"
    WARNING = "warning"
    ADVICE = "advice"
    DISABLED = "disabled"


class LintMessage(NamedTuple):
    """A single lint finding, emitted as one JSON object per message.

    Fields that do not apply to a particular finding are left as None.
    """

    path: str | None  # file the finding refers to
    line: int | None  # line of the finding; None for a whole-file finding
    char: int | None  # column of the finding; None when not applicable
    code: str  # identifier of the linter that produced the finding
    severity: LintSeverity
    name: str  # short tag for the specific check, e.g. "[no-main]"
    original: str | None  # presumably the text a suggested fix replaces
    replacement: str | None  # presumably the suggested replacement text
    description: str | None  # human-readable explanation


def check_file(filename: str) -> list[LintMessage]:
    """Lint one test file for a qualifying ``__main__`` block.

    Args:
        filename: Path of the Python test file to check.

    Returns:
        An empty list when ``HasMainVisiter`` finds a top-level
        ``if __name__ == "__main__":`` block that raises or calls
        ``run_tests``/``run_rank``. Otherwise, a single ERROR
        ``[no-main]`` message. A file that libcst cannot parse yields a
        single ERROR ``[syntax-error]`` message instead of raising. This
        means one bad file does not abort the whole (multiprocess) lint run.
    """
    # Read raw bytes so libcst can honour a PEP 263 coding cookie / BOM
    # (defaulting to UTF-8) instead of depending on the locale encoding.
    with open(filename, "rb") as f:
        source = f.read()

    try:
        module = cst.parse_module(source)
    except cst.ParserSyntaxError as e:
        return [
            LintMessage(
                path=filename,
                line=e.raw_line,
                char=None,
                code=LINTER_CODE,
                severity=LintSeverity.ERROR,
                name="[syntax-error]",
                original=None,
                replacement=None,
                description=f"Failed to parse file: {e.message}",
            )
        ]

    visitor = HasMainVisiter()
    module.visit(visitor)
    if visitor.found:
        return []

    message = (
        "Test files need to have a main block which either calls run_tests "
        "(to ensure that the tests are run during OSS CI) or raises an exception "
        "and added to the blocklist in test/run_test.py"
    )
    return [
        LintMessage(
            path=filename,
            line=None,
            char=None,
            code=LINTER_CODE,
            severity=LintSeverity.ERROR,
            name="[no-main]",
            original=None,
            replacement=None,
            description=message,
        )
    ]


def main() -> None:
    """Lint every file named on the command line and print the findings.

    Files are checked in parallel with ``check_file`` in a process pool. Each
    resulting ``LintMessage`` is written to stdout as one JSON object per
    line. Arguments may also be read from a file by passing ``@path``.
    """
    parser = argparse.ArgumentParser(
        description="test files should have main block linter",
        fromfile_prefix_chars="@",
    )
    parser.add_argument(
        "filenames",
        nargs="+",
        help="paths to lint",
    )

    args = parser.parse_args()

    # nargs="+" guarantees at least one file, so the pool size is >= 1.
    # Capping at the file count avoids spawning idle workers. The `with`
    # block tears the pool down even if a worker raises; on success map()
    # has already returned every result before Pool.__exit__ runs.
    with mp.Pool(min(8, len(args.filenames))) as pool:
        per_file_messages = pool.map(check_file, args.filenames)

    for file_messages in per_file_messages:
        for lint_message in file_messages:
            print(json.dumps(lint_message._asdict()), flush=True)


if __name__ == "__main__":
    # Run the linter only when executed as a script, not when imported
    # (multiprocessing workers may re-import this module).
    main()
