from mypy.plugin import ClassDefContext, Plugin
from mypy.plugins.common import add_attribute_to_class
from mypy.types import NoneType, UnionType


class SympyPlugin(Plugin):
    """mypy plugin that declares SymPy's ``is_<assumption>`` attributes.

    The attributes are attached by ``add_assumptions`` to classes for which
    mypy reports ``sympy.core.basic.Basic`` as a base class.
    """

    def get_base_class_hook(self, fullname: str):
        """Pick the class-definition hook for a class with base ``fullname``.

        :param fullname: fully qualified name of the base class, as passed in
            by mypy.
        :returns: ``add_assumptions`` when ``fullname`` is
            ``sympy.core.basic.Basic``; ``None`` (no hook) for any other base.
        """
        if fullname != "sympy.core.basic.Basic":
            return None
        return add_assumptions


# Names of SymPy's predefined assumptions; each becomes an ``is_<name>``
# attribute on classes handled by ``add_assumptions``.
# Generated by list(sys.modules['sympy.core.assumptions']._assume_defined)
# (do not import sympy to speedup mypy plugin load time)
# Kept as a module-level tuple so it is built once at import time rather than
# on every class definition the hook is invoked for.
SYMPY_ASSUMPTIONS = (
    "hermitian",
    "prime",
    "noninteger",
    "negative",
    "antihermitian",
    "infinite",
    "finite",
    "irrational",
    "extended_positive",
    "nonpositive",
    "odd",
    "algebraic",
    "integer",
    "rational",
    "extended_real",
    "nonnegative",
    "transcendental",
    "extended_nonzero",
    "extended_negative",
    "composite",
    "complex",
    "imaginary",
    "nonzero",
    "zero",
    "even",
    "positive",
    "polar",
    "extended_nonpositive",
    "extended_nonnegative",
    "real",
    "commutative",
)


def add_assumptions(ctx: "ClassDefContext") -> None:
    """Declare every SymPy assumption attribute on the class being analyzed.

    For each name in ``SYMPY_ASSUMPTIONS``, an attribute ``is_<name>`` of type
    ``bool | None`` is added to ``ctx.cls`` via mypy's
    ``add_attribute_to_class`` helper.

    :param ctx: the class-definition context mypy passes to base-class hooks.
        ``ctx.api`` is used to look up ``builtins.bool`` and to insert the
        attributes; ``ctx.cls`` is the class that receives them.
    """
    for assumption in SYMPY_ASSUMPTIONS:
        # Build a fresh ``bool | None`` type for each attribute so that no
        # type object is shared between the generated attributes.
        # NOTE(review): ``None`` presumably models SymPy's "unknown" assumption
        # value — confirm against sympy.core.assumptions.
        add_attribute_to_class(
            ctx.api,
            ctx.cls,
            f"is_{assumption}",
            UnionType([ctx.api.named_type("builtins.bool"), NoneType()]),
        )


def plugin(version: str):
    """Plugin entry point that mypy calls when loading this module.

    :param version: the running mypy version string; not consulted, since the
        same plugin class is returned for every version.
    :returns: the ``SympyPlugin`` class itself (not an instance).
    """
    plugin_class = SympyPlugin
    return plugin_class
