# Owner(s): ["module: dynamo"]

import sys

import pytest

import torch._numpy as tnp


def pytest_configure(config):
    """Register the custom ``slow`` marker in pytest's ``markers`` ini setting."""
    slow_marker_line = "slow: very slow tests"
    config.addinivalue_line("markers", slow_marker_line)


def pytest_addoption(parser):
    """Register the boolean command-line flags ``--runslow`` and ``--nonp``."""
    # (flag, help text) pairs; each flag is a plain on/off switch.
    flag_table = (
        ("--runslow", "run slow tests"),
        ("--nonp", "error when NumPy is accessed"),
    )
    for flag, description in flag_table:
        parser.addoption(flag, action="store_true", help=description)


class Inaccessible:
    """Placeholder object that refuses every attribute lookup.

    An instance is installed as ``sys.modules["numpy"]`` when ``--nonp`` is
    given, so that any use of NumPy fails with a ``RuntimeError`` naming the
    attribute that was requested.
    """

    def __getattribute__(self, attr):
        """Raise ``RuntimeError`` instead of resolving *attr*."""
        message = f"Using --nonp but accessed np.{attr}"
        raise RuntimeError(message)


def pytest_sessionstart(session):
    """Block NumPy for the session when ``--nonp`` was passed.

    The ``numpy`` entry of ``sys.modules`` is replaced by an ``Inaccessible``
    instance, which raises on any attribute access.
    """
    nonp_requested = session.config.getoption("--nonp")
    if not nonp_requested:
        return
    sys.modules["numpy"] = Inaccessible()


def pytest_generate_tests(metafunc):
    """
    Hook to parametrize test cases
    See https://docs.pytest.org/en/6.2.x/parametrize.html#pytest-generate-tests

    The logic here allows us to test with both NumPy-proper and torch._numpy.
    Normally we'd just test torch._numpy, e.g.

        import torch._numpy as np
        ...
        def test_foo():
            np.array([42])
            ...

    but this hook allows us to test NumPy-proper as well, e.g.

        def test_foo(np):
            np.array([42])
            ...

    np is a pytest parameter, which is either NumPy-proper or torch._numpy. This
    allows us to sanity check our own tests, so that tested behaviour is
    consistent with NumPy-proper.

    pytest will have test names respective to the library being tested, e.g.

        $ pytest --collect-only
        test_foo[torch._numpy]
        test_foo[numpy]

    NumPy-proper is left out of the parameter list when it cannot be imported,
    or when --nonp replaced it in sys.modules with an Inaccessible placeholder.
    Tests that do not request ``np`` are left untouched.
    """
    # Nothing to parametrize (and no reason to import NumPy at all) unless the
    # test actually asks for the ``np`` argument.
    if "np" not in metafunc.fixturenames:
        return

    np_params = [tnp]

    # With --nonp, sys.modules["numpy"] holds an Inaccessible placeholder whose
    # every attribute lookup raises RuntimeError. Detect it straight from
    # sys.modules instead of running ``import numpy`` first, so that nothing
    # gets a chance to touch the placeholder's attributes.
    # NOTE(review): presumably some import-machinery versions inspect the
    # cached module (e.g. ``__spec__``) during ``import`` -- confirm.
    if not isinstance(sys.modules.get("numpy"), Inaccessible):
        try:
            import numpy as np
        except ImportError:
            # NumPy-proper is unavailable; only torch._numpy gets tested.
            pass
        else:
            np_params.append(np)

    metafunc.parametrize("np", np_params)


def pytest_collection_modifyitems(config, items):
    """Mark every collected ``slow`` item as skipped unless ``--runslow`` is set.

    With ``--runslow`` the collected items are left untouched.
    """
    if config.getoption("--runslow"):
        return

    # One shared skip marker is attached to each item carrying "slow".
    slow_skip_marker = pytest.mark.skip(reason="slow test, use --runslow to run")
    for test_item in items:
        if "slow" not in test_item.keywords:
            continue
        test_item.add_marker(slow_skip_marker)
