# Owner(s): ["module: unknown"]

import itertools
import logging
import re

import torch
from torch import nn
from torch.ao.pruning import (
    BaseSparsifier,
    FakeSparsity,
    NearlyDiagonalSparsifier,
    WeightNormSparsifier,
)
from torch.nn.utils.parametrize import is_parametrized
from torch.testing._internal.common_pruning import (
    ImplementedSparsifier,
    MockSparseLinear,
    SimpleLinear,
)
from torch.testing._internal.common_utils import TestCase


# Configure root logging once at import time: INFO level, with each record
# prefixed by timestamp, logger name and severity.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestBaseSparsifier(TestCase):
    def test_constructor(self):
        # Cannot instantiate the abstract base
        self.assertRaises(TypeError, BaseSparsifier)
        # Can instantiate the model with no configs
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(test=3)
        sparsifier.prepare(model, config=None)
        assert len(sparsifier.groups) == 5
        sparsifier.step()
        # Can instantiate the model with configs
        sparsifier = ImplementedSparsifier(test=3)
        sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
        assert len(sparsifier.groups) == 1
        assert sparsifier.groups[0]["tensor_fqn"] == "linear1.weight"
        assert "test" in sparsifier.groups[0]
        assert sparsifier.groups[0]["test"] == 3

    def test_prepare_config(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(test=3)
        # Make sure there are no parametrizations before `prepare`
        assert not hasattr(model.seq[0], "parametrizations")
        assert not hasattr(model.linear1, "parametrizations")
        assert not hasattr(model.linear2, "parametrizations")
        sparsifier.prepare(
            model,
            config=[
                {"tensor_fqn": "seq.0.weight", "test": 42},
                # No 'linear1' to make sure it will be skipped in the sparsification
                {"tensor_fqn": "linear2.weight"},
            ],
        )
        assert len(sparsifier.groups) == 2
        # Check if default argument is not assigned if explicit
        assert sparsifier.groups[0]["tensor_fqn"] == "seq.0.weight"
        assert sparsifier.groups[0]["test"] == 42
        # Check if FQN and module are pointing to the same location
        assert sparsifier.groups[1]["tensor_fqn"] == "linear2.weight"
        assert sparsifier.groups[1]["module"] == model.linear2
        # Check if parameterizations are attached
        assert hasattr(model.seq[0], "parametrizations")
        assert not hasattr(model.linear1, "parametrizations")
        assert hasattr(model.linear2, "parametrizations")

    def test_step(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(test=3)
        sparsifier.enable_mask_update = True
        sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
        sparsifier.step()
        assert torch.all(model.linear1.parametrizations.weight[0].mask[0] == 0)

    def test_state_dict(self):
        step_count = 3
        model0 = SimpleLinear()
        sparsifier0 = ImplementedSparsifier(test=3)
        sparsifier0.prepare(model0, [{"tensor_fqn": "linear1.weight"}])
        mask = model0.linear1.parametrizations["weight"][0].mask
        mask.data = torch.arange(mask.shape[0] * mask.shape[1]).reshape(mask.shape)
        for step in range(step_count):
            sparsifier0.step()
        state_dict = sparsifier0.state_dict()

        # Check the expected keys in the state_dict
        assert "state" in state_dict
        assert "step_count" in state_dict["state"]["linear1.weight"]
        assert state_dict["state"]["linear1.weight"]["step_count"] == 3
        assert "groups" in state_dict
        assert "test" in state_dict["groups"][0]
        assert "tensor_fqn" in state_dict["groups"][0]
        assert state_dict["groups"][0]["tensor_fqn"] == "linear1.weight"

        # Check loading static_dict creates an equivalent model
        model1 = SimpleLinear()
        sparsifier1 = ImplementedSparsifier()
        sparsifier1.prepare(model1, None)

        assert sparsifier0.state != sparsifier1.state

        # Make sure the masks are different in the beginning
        for mg in sparsifier0.groups:
            if mg["tensor_fqn"] == "linear1.weight":
                mask0 = mg["module"].parametrizations.weight[0].mask
        for mg in sparsifier1.groups:
            if mg["tensor_fqn"] == "linear1.weight":
                mask1 = mg["module"].parametrizations.weight[0].mask
        self.assertNotEqual(mask0, mask1)

        sparsifier1.load_state_dict(state_dict)

        # Make sure the states are loaded, and are correct
        assert sparsifier0.state == sparsifier1.state

        # Make sure the masks (and all dicts) are the same after loading
        assert len(sparsifier0.groups) == len(sparsifier1.groups)
        for idx in range(len(sparsifier0.groups)):
            mg0 = sparsifier0.groups[idx]
            mg1 = sparsifier1.groups[idx]
            for key in mg0.keys():
                assert key in mg1
                if key == "module":
                    # We cannot compare modules as they are different
                    param0 = mg0[key].parametrizations.weight[0]
                    param1 = mg1[key].parametrizations.weight[0]
                    assert hasattr(param0, "mask")
                    assert hasattr(param1, "mask")
                    self.assertEqual(param0.__dict__, param1.__dict__)
                else:
                    assert mg0[key] == mg1[key]

    def test_convert(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(test=3)
        sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
        new_model = sparsifier.convert(
            model, mapping={nn.Linear: MockSparseLinear}, inplace=False
        )

        assert isinstance(new_model.linear1, MockSparseLinear)
        assert isinstance(new_model.seq[0], nn.Linear)
        assert isinstance(new_model.linear2, nn.Linear)

    def test_mask_squash(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(test=3)
        sparsifier.prepare(model, [{"tensor_fqn": "linear1.weight"}])
        assert hasattr(model.linear1.parametrizations.weight[0], "mask")
        assert is_parametrized(model.linear1, "weight")
        assert not is_parametrized(model.seq[0], "weight")

        sparsifier.squash_mask()
        assert not is_parametrized(model.seq[0], "weight")
        assert not is_parametrized(model.linear1, "weight")

    def test_mask_squash_with_params1(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
        sparsifier.prepare(
            model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
        )
        sparsifier.squash_mask(
            params_to_keep_per_layer={"linear1": ("foo", "bar"), "seq.0": ("baz",)}
        )
        assert not is_parametrized(model.seq[0], "weight")
        assert not is_parametrized(model.linear1, "weight")
        assert hasattr(model.seq[0], "sparse_params")
        assert hasattr(model.linear1, "sparse_params")
        assert model.seq[0].sparse_params.get("foo", None) is None
        assert model.seq[0].sparse_params.get("bar", None) is None
        assert model.seq[0].sparse_params.get("baz", None) == 1
        assert model.linear1.sparse_params.get("foo", None) == 3
        assert model.linear1.sparse_params.get("bar", None) == 2
        assert model.linear1.sparse_params.get("baz", None) is None

    def test_mask_squash_with_params2(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
        sparsifier.prepare(
            model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
        )
        sparsifier.squash_mask(params_to_keep=("foo", "bar"))
        assert not is_parametrized(model.seq[0], "weight")
        assert not is_parametrized(model.linear1, "weight")
        assert hasattr(model.seq[0], "sparse_params")
        assert hasattr(model.linear1, "sparse_params")
        assert model.seq[0].sparse_params.get("foo", None) == 3
        assert model.seq[0].sparse_params.get("bar", None) == 2
        assert model.seq[0].sparse_params.get("baz", None) is None
        assert model.linear1.sparse_params.get("foo", None) == 3
        assert model.linear1.sparse_params.get("bar", None) == 2
        assert model.linear1.sparse_params.get("baz", None) is None

    def test_mask_squash_with_params3(self):
        model = SimpleLinear()
        sparsifier = ImplementedSparsifier(foo=3, bar=2, baz=1)
        sparsifier.prepare(
            model, [{"tensor_fqn": "linear1.weight"}, {"tensor_fqn": "seq.0.weight"}]
        )
        sparsifier.squash_mask(
            params_to_keep=("foo", "bar"), params_to_keep_per_layer={"seq.0": ("baz",)}
        )
        assert not is_parametrized(model.seq[0], "weight")
        assert not is_parametrized(model.linear1, "weight")
        assert hasattr(model.seq[0], "sparse_params")
        assert hasattr(model.linear1, "sparse_params")
        assert model.seq[0].sparse_params.get("foo", None) == 3
        assert model.seq[0].sparse_params.get("bar", None) == 2
        assert model.seq[0].sparse_params.get("baz", None) == 1
        assert model.linear1.sparse_params.get("foo", None) == 3
        assert model.linear1.sparse_params.get("bar", None) == 2
        assert model.linear1.sparse_params.get("baz", None) is None


class TestWeightNormSparsifier(TestCase):
    def test_constructor(self):
        model = SimpleLinear()
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        for g in sparsifier.groups:
            assert isinstance(g["module"], nn.Linear)
            # The groups are unordered
            assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2")

    def test_step(self):
        model = SimpleLinear()
        sparsifier = WeightNormSparsifier(sparsity_level=0.5)
        sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
        for g in sparsifier.groups:
            # Before step
            module = g["module"]
            assert (
                1.0 - module.parametrizations["weight"][0].mask.mean()
            ) == 0  # checking sparsity level is 0
        sparsifier.enable_mask_update = True
        sparsifier.step()
        self.assertAlmostEqual(
            model.linear1.parametrizations["weight"][0].mask.mean().item(),
            0.5,
            places=2,
        )
        for g in sparsifier.groups:
            # After step
            module = g["module"]
            assert (
                1.0 - module.parametrizations["weight"][0].mask.mean()
            ) > 0  # checking sparsity level has increased
        # Test if the mask collapses to all zeros if the weights are randomized
        iters_before_collapse = 1000
        for _ in range(iters_before_collapse):
            model.linear1.weight.data = torch.randn(model.linear1.weight.shape)
            sparsifier.step()
        for g in sparsifier.groups:
            # After step
            module = g["module"]
            assert (
                1.0 - module.parametrizations["weight"][0].mask.mean()
            ) > 0  # checking sparsity level did not collapse

    def test_step_2_of_4(self):
        model = SimpleLinear()
        sparsifier = WeightNormSparsifier(
            sparsity_level=1.0, sparse_block_shape=(1, 4), zeros_per_block=2
        )
        sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])
        sparsifier.step()
        # make sure the sparsity level is approximately 50%
        mask = model.linear1.parametrizations["weight"][0].mask.to(
            torch.float
        )  # mean works on float only
        self.assertAlmostEqual(mask.mean().item(), 0.5, places=2)
        # Make sure each block has exactly 50% zeros
        module = sparsifier.groups[0]["module"]
        mask = module.parametrizations["weight"][0].mask
        for row in mask:
            for idx in range(0, len(row), 4):
                block = row[idx : idx + 4]
                block, _ = block.sort()
                assert (block[:2] == 0).all()
                assert (block[2:] != 0).all()

    def test_prepare(self):
        model = SimpleLinear()
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        for g in sparsifier.groups:
            module = g["module"]
            # Check mask exists
            assert hasattr(module.parametrizations["weight"][0], "mask")
            # Check parametrization exists and is correct
            assert is_parametrized(module, "weight")
            assert type(module.parametrizations.weight[0]) == FakeSparsity

    def test_mask_squash(self):
        model = SimpleLinear()
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        sparsifier.squash_mask()
        for g in sparsifier.groups:
            module = g["module"]
            assert not is_parametrized(module, "weight")
            assert not hasattr(module, "mask")

    def test_sparsity_levels(self):
        sparsity_levels = [-1.0, 0.0, 0.5, 1.0, 2.0]
        sparse_block_shapes = [(1, 1), (1, 4), (2, 2), (4, 1)]
        zeros_per_blocks = [0, 1, 2, 3, 4]

        testcases = itertools.tee(
            itertools.product(sparsity_levels, sparse_block_shapes, zeros_per_blocks)
        )
        # Create a config and model with all the testcases
        model = nn.Sequential()
        sparsifier = WeightNormSparsifier()

        sparsity_per_layer_config = []
        p = re.compile(r"[-\.\s]")
        for sl, sbs, zpb in testcases[0]:
            # Make sure the number of zeros is not > values in a block
            if zpb > sbs[0] * sbs[1]:
                continue
            layer_name = f"{sl}_{sbs}_{zpb}"
            layer_name = p.sub("_", layer_name)

            layer = nn.Linear(12, 12, bias=False)
            layer.weight = nn.Parameter(torch.ones(12, 12))
            model.add_module(layer_name, layer)
            config = {
                "tensor_fqn": layer_name + ".weight",
                "sparsity_level": sl,
                "sparse_block_shape": sbs,
                "zeros_per_block": zpb,
            }
            sparsity_per_layer_config.append(config)

        sparsifier.prepare(model, sparsity_per_layer_config)
        sparsifier.step()
        sparsifier.squash_mask()
        model.eval()

        for sl, sbs, zpb in testcases[1]:
            if zpb > sbs[0] * sbs[1]:
                continue
            layer_name = f"{sl}_{sbs}_{zpb}"
            layer_name = p.sub("_", layer_name)
            layer = getattr(model, layer_name)

            # Level of sparsity is achieved
            sparse_mask = (layer.weight == 0).float()
            if zpb == 0:
                assert sparse_mask.mean() == 0
            else:
                # Ratio of individual zeros in the tensor
                true_sl = min(max(sl, 0.0), 1.0)
                true_sl = true_sl * zpb / sbs[0] / sbs[1]
                assert sparse_mask.mean() == true_sl


class TestNearlyDiagonalSparsifier(TestCase):
    def test_constructor(self):
        model = SimpleLinear()
        sparsifier = NearlyDiagonalSparsifier(nearliness=1)
        sparsifier.prepare(model, config=None)
        for g in sparsifier.groups:
            assert isinstance(g["module"], nn.Linear)
            # The groups are unordered
            assert g["module_fqn"] in ("seq.0", "seq.1", "seq.2", "linear1", "linear2")

    def test_step(self):
        model = SimpleLinear()
        sparsifier = NearlyDiagonalSparsifier(nearliness=1)
        sparsifier.prepare(model, config=[{"tensor_fqn": "linear1.weight"}])

        for g in sparsifier.groups:
            # Before step
            module = g["module"]
            assert (
                1.0 - module.parametrizations["weight"][0].mask.mean()
            ) == 0  # checking sparsity level is 0

        sparsifier.enable_mask_update = True
        sparsifier.step()
        mask = module.parametrizations["weight"][0].mask
        height, width = mask.shape
        assert torch.all(mask == torch.eye(height, width))

        for g in sparsifier.groups:
            # After step
            module = g["module"]
            assert (
                1.0 - module.parametrizations["weight"][0].mask.mean()
            ) > 0  # checking sparsity level has increased

        # Test if the mask collapses to all zeros if the weights are randomized
        iters_before_collapse = 1000
        for _ in range(iters_before_collapse):
            model.linear1.weight.data = torch.randn(model.linear1.weight.shape)
            sparsifier.step()
        for g in sparsifier.groups:
            # After step
            module = g["module"]
            assert (
                1.0 - module.parametrizations["weight"][0].mask.mean()
            ) > 0  # checking sparsity level did not collapse

    def test_prepare(self):
        model = SimpleLinear()
        sparsifier = NearlyDiagonalSparsifier(nearliness=1)
        sparsifier.prepare(model, config=None)
        for g in sparsifier.groups:
            module = g["module"]
            # Check mask exists
            assert hasattr(module.parametrizations["weight"][0], "mask")
            # Check parametrization exists and is correct
            assert is_parametrized(module, "weight")
            assert type(module.parametrizations.weight[0]) == FakeSparsity

    def test_mask_squash(self):
        model = SimpleLinear()
        sparsifier = NearlyDiagonalSparsifier(nearliness=1)
        sparsifier.prepare(model, config=None)
        sparsifier.step()
        sparsifier.squash_mask()
        for g in sparsifier.groups:
            module = g["module"]
            assert not is_parametrized(module, "weight")
            assert not hasattr(module, "mask")
            weights = module.weight
            height, width = weights.shape
            assert torch.all(
                weights == torch.eye(height, width) * weights
            )  # only diagonal to be present

    def test_sparsity_levels(self):
        nearliness_levels = list(range(-1, 100))
        model = nn.Sequential()

        p = re.compile(r"[-\.\s]")
        for nearliness in nearliness_levels:
            sparsifier = NearlyDiagonalSparsifier(nearliness=1)
            layer_name = f"{nearliness}"
            layer_name = p.sub("_", layer_name)

            layer = nn.Linear(32, 32, bias=False)
            layer.weight = nn.Parameter(torch.ones(32, 32))
            width, height = layer.weight.shape
            model.add_module(layer_name, layer)
            config = {"tensor_fqn": layer_name + ".weight", "nearliness": nearliness}

            sparsifier.prepare(model, [config])
            # should raise a ValueError when nearliness arg is illegal
            if (nearliness > 0 and nearliness % 2 == 0) or (
                nearliness // 2 >= min(width, height)
            ):
                with self.assertRaises(ValueError):
                    sparsifier.step()
            else:
                sparsifier.step()
                sparsifier.squash_mask()
                model.eval()

                layer = getattr(model, layer_name)
                # verify that mask created corresponds to the nearliness
                self._verify_nearliness(layer.weight, nearliness)

    # helper function to verify nearliness of a mask
    def _verify_nearliness(self, mask: torch.Tensor, nearliness: int):
        if nearliness <= 0:
            assert torch.all(mask == torch.zeros(mask.shape[0], mask.shape[1]))
        else:
            height, width = mask.shape
            dist_to_diagonal = nearliness // 2
            for row in range(0, height):
                for col in range(0, width):
                    if abs(row - col) <= dist_to_diagonal:
                        assert mask[row, col] == 1
                    else:
                        assert mask[row, col] == 0
