import sys
import textwrap
from pathlib import Path


def _find_enclosing_test_class(src, start):
    """Return the index of the nearest enclosing test class line, or None.

    Scans ``src`` backwards from index ``start`` (inclusive) for a
    non-comment line containing ``"class Test"`` whose indentation is
    smaller than 8 characters.
    """
    for idx in range(start, -1, -1):
        candidate = src[idx]
        # a commented-out class statement does not enclose anything
        if is_comment(candidate) or "class Test" not in candidate:
            continue
        # hack: large indent => likely an inner class
        if len(candidate) - len(candidate.lstrip()) < 8:
            return idx
    return None


def check(path):
    """Check a test file for common issues with pytest->pytorch conversion.

    Prints the file name as a header and then one report (through
    ``report_violation``) for every non-comment line that:

    * defines a module-level ``def test...`` function;
    * starts a ``class Test...`` whose header does not mention ``TestCase``;
    * mentions ``pytest.mark``, ``pytest.xfail``, ``pytest.skip`` or
      ``pytest.param``;
    * is a ``@parametrize`` decorator that either has no enclosing test
      class, or whose enclosing class is not directly preceded by a
      ``@instantiate_parametrized_tests`` line.

    Reported line numbers are 1-based, matching what editors display.

    Args:
        path: ``pathlib.Path`` of the test file to inspect.
    """
    print(path.name)
    print("=" * len(path.name), "\n")

    # Python source files are UTF-8 by default; do not depend on the locale.
    src = path.read_text(encoding="utf-8").split("\n")
    for num, line in enumerate(src):
        if is_comment(line):
            continue

        # ``num`` indexes ``src``; humans and editors count lines from 1
        lineno = num + 1

        # module level test functions
        if line.startswith("def test"):
            report_violation(line, lineno, header="Module-level test function")

        # test classes must inherit from TestCase
        if line.startswith("class Test") and "TestCase" not in line:
            report_violation(
                line, lineno, header="Test class does not inherit from TestCase"
            )

        # last vestiges of pytest-specific stuff
        if "pytest.mark" in line:
            report_violation(line, lineno, header="pytest.mark.something")

        for part in ["pytest.xfail", "pytest.skip", "pytest.param"]:
            if part in line:
                report_violation(line, lineno, header=f"stray {part}")

        if textwrap.dedent(line).startswith("@parametrize"):
            # backtrack to find the class this decorator lives in
            class_idx = _find_enclosing_test_class(src, num)
            if class_idx is None:
                # no class => there is nothing whose instantiation we could
                # check, so report only the off-class usage
                report_violation(line, lineno, "off-class parametrize")
                continue

            # a class on the very first line has no preceding line; guard the
            # index so that ``class_idx - 1`` never wraps around to the file end
            instantiated = class_idx > 0 and src[class_idx - 1].startswith(
                "@instantiate_parametrized_tests"
            )
            if not instantiated:
                report_violation(
                    line,
                    lineno,
                    "missing instantiation of parametrized tests in "
                    f"{src[class_idx]}?",
                )


def is_comment(line):
    return textwrap.dedent(line).startswith("#")


def report_violation(line, lineno, header):
    print(f">>>> line {lineno} : {header}\n {line}\n")


if __name__ == "__main__":
    # Command-line entry point: takes exactly one argument, which is either a
    # single test file or a directory containing test files.
    argv = sys.argv
    if len(argv) != 2:
        raise ValueError("Usage : python check_tests_conform path/to/file/or/dir")

    path = Path(argv[1])

    if path.is_dir():
        # run for all files in the directory (no subdirs); glob yields entries
        # in filesystem order, so sort to keep the report order reproducible
        for this_path in sorted(path.glob("test*.py")):
            check(this_path)
    else:
        check(path)
