import torch


def check_error(desc, fn, *required_substrings):
    try:
        fn()
    except Exception as e:
        error_message = e.args[0]
        print("=" * 80)
        print(desc)
        print("-" * 80)
        print(error_message)
        print()
        for sub in required_substrings:
            assert sub in error_message
        return
    raise AssertionError(f"given function ({desc}) didn't raise an error")


check_error("Wrong argument types", lambda: torch.FloatStorage(object()), "object")

check_error(
    "Unknown keyword argument", lambda: torch.FloatStorage(content=1234.0), "keyword"
)

check_error(
    "Invalid types inside a sequence",
    lambda: torch.FloatStorage(["a", "b"]),
    "list",
    "str",
)

check_error("Invalid size type", lambda: torch.FloatStorage(1.5), "float")

check_error(
    "Invalid offset", lambda: torch.FloatStorage(torch.FloatStorage(2), 4), "2", "4"
)

check_error(
    "Negative offset", lambda: torch.FloatStorage(torch.FloatStorage(2), -1), "2", "-1"
)

check_error(
    "Invalid size",
    lambda: torch.FloatStorage(torch.FloatStorage(3), 1, 5),
    "2",
    "1",
    "5",
)

check_error(
    "Negative size",
    lambda: torch.FloatStorage(torch.FloatStorage(3), 1, -5),
    "2",
    "1",
    "-5",
)

check_error("Invalid index type", lambda: torch.FloatStorage(10)["first item"], "str")


def assign():
    torch.FloatStorage(10)[1:-1] = "1"


check_error("Invalid value type", assign, "str")

check_error(
    "resize_ with invalid type", lambda: torch.FloatStorage(10).resize_(1.5), "float"
)

check_error(
    "fill_ with invalid type", lambda: torch.IntStorage(10).fill_("asdf"), "str"
)

# TODO: frombuffer
