# Owner(s): ["module: unknown"]
import copy
import logging
import random

import torch
from torch import nn
from torch.ao.pruning._experimental.pruner import (
    BaseStructuredSparsifier,
    FakeStructuredSparsity,
    FPGMPruner,
    LSTMSaliencyPruner,
    SaliencyPruner,
)
from torch.nn.utils import parametrize
from torch.testing._internal.common_pruning import (
    Conv2dActivation,
    Conv2dBias,
    Conv2dPadBias,
    Conv2dPool,
    Conv2dPoolFlatten,
    Conv2dPoolFlattenFunctional,
    LinearActivation,
    LinearActivationFunctional,
    LinearBias,
    LSTMLayerNormLinearModel,
    LSTMLinearModel,
    rows_are_subset,
    SimpleConv2d,
    SimpleLinear,
)
from torch.testing._internal.common_utils import skipIfTorchDynamo, TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

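# Build the device set so the duplicate cpu entry collapses when CUDA is
# unavailable and each test runs once per distinct device.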
DEVICES = {
    torch.device("cpu"),
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
}


class SimplePruner(BaseStructuredSparsifier):
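    """Minimal test pruner that masks out a single row (index 1) of each configured tensor."""
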
    def update_mask(self, module, tensor_name, **kwargs):
        getattr(module.parametrizations, tensor_name)[0].mask[1] = False


class ImplementedPruner(BaseStructuredSparsifier):
    def update_mask(self, module, tensor_name, **kwargs):
        """Prunes 1/3 of the weight output channels, so resulting module has 33.3% pruning"""
        num_rows = len(module.parametrizations[tensor_name][0].mask)
        prune = random.sample(list(range(num_rows)), num_rows // 3)
        module.parametrizations[tensor_name][0].mask[prune] = False


class BottomHalfLSTMPruner(BaseStructuredSparsifier):
    """
    Pruner that will remove the bottom half of the rows.
    This is primarily meant for testing purposes
    """

    def update_mask(self, module, tensor_name, **kwargs):
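        # An LSTM weight stacks the four gate blocks (input, forget, cell,
        # output) along dim 0, so split the mask into quarters and prune the
        # bottom half of each gate block independently.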
        for p in getattr(module.parametrizations, tensor_name):
            if isinstance(p, FakeStructuredSparsity):
                mask = p.mask
                masks = torch.split(mask, len(mask) // 4)
                for small in masks:
                    num = len(small)
                    small[num // 2 :] = False
                new_mask = torch.cat(masks)
                mask.data = new_mask.data


class TestSaliencyPruner(TestCase):
    def test_saliency_pruner_update_mask(self):
        """Test that we prune out the row with the lowest saliency (first row)"""
        model = SimpleLinear()
        with torch.no_grad():
            model.linear1.weight = nn.Parameter(
                torch.Tensor([[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]])
            )
        pruning_config = [{"tensor_fqn": "linear1.weight", "sparsity_level": 0.5}]
        pruner = SaliencyPruner({})

        pruner.prepare(model, pruning_config)
        pruner.enable_mask_update = True
        pruner.step()
        pruned_model = pruner.prune()

        expected = torch.Tensor([[3, 3, 3, 3], [4, 4, 4, 4]])
        pruned = pruned_model.linear1.weight

        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

    def test_lstm_saliency_pruner_update_mask(self):
        model = LSTMLinearModel(
            input_dim=2,
            hidden_dim=2,
            output_dim=2,
            num_layers=1,
        )

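        # 8 rows = 4 LSTM gate blocks x hidden_dim 2; within each gate block
        # the +/-1 rows have a lower saliency than the +/-2 rows.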
        manual_weights = torch.Tensor(
            [[1, 1], [2, 2], [2, 2], [1, 1], [-1, -1], [-2, -2], [-2, -2], [-1, -1]]
        )

        with torch.no_grad():
            model.lstm.weight_ih_l0 = nn.Parameter(manual_weights)
            model.lstm.weight_hh_l0 = nn.Parameter(torch.Tensor(manual_weights))
            model.lstm.bias_ih_l0 = nn.Parameter(manual_weights[:, 0])
            model.lstm.bias_hh_l0 = nn.Parameter(manual_weights[:, 0])

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]
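        # unbatched LSTM input of shape (seq_len, input_dim) = (1, 2)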
        lstm_input = torch.ones((1, 2))
        fx_pruner = LSTMSaliencyPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
        pruned_model = fx_pruner.prune()
        pruned_model.eval()

        # make sure both models run
        model(lstm_input)
        pruned_model(lstm_input)

        # make sure lowest saliency rows are pruned
        expected = torch.Tensor([[2, 2], [2, 2], [-2, -2], [-2, -2]])
        pruned = model.lstm.weight_ih_l0
        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

        expected = torch.Tensor([[2], [2], [-2], [-2]])
        pruned = model.lstm.weight_hh_l0
        assert expected.shape == pruned.shape
        assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()

        expected = torch.Tensor([2, 2, -2, -2])
        for pruned in [model.lstm.bias_ih_l0, model.lstm.bias_hh_l0]:
            assert expected.shape == pruned.shape
            assert torch.isclose(expected, pruned, rtol=1e-05, atol=1e-07).all()


class TestBaseStructuredSparsifier(TestCase):
    def _check_pruner_prepared(self, model, pruner, device):
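        """Verify that each configured module sits on the expected device and is
        parametrized with a FakeStructuredSparsity mask."""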
        for config in pruner.groups:
            module = config["module"]
            assert module.weight.device.type == device.type
            # Check mask exists
            assert config["tensor_fqn"] in pruner.state
            # Check parametrization exists and is correct
            assert parametrize.is_parametrized(module)
            assert hasattr(module, "parametrizations")
            # Assume that this is the 1st/only parametrization
            assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity

    def _check_pruner_valid_before_step(self, model, pruner, device):
        for config in pruner.groups:
            modules = []
            if type(config["module"]) is tuple:
                modules.extend(config["module"])
            else:
                module = config["module"]
                modules.append(module)
            for module in modules:
                assert module.weight.device.type == device.type
                assert module.parametrizations.weight[0].mask.dtype == torch.bool

    def _check_pruner_valid_after_step(self, model, pruner, num_pruned, device):
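        """Check that, after a step, each mask has exactly ``num_pruned`` entries set to False."""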
        for config in pruner.groups:
            modules = []
            if type(config["module"]) is tuple:
                modules.extend(config["module"])
            else:
                module = config["module"]
                modules.append(module)
            for module in modules:
                assert module.weight.device.type == device.type
                total = module.parametrizations.weight[0].mask.numel()
                assert (
                    module.parametrizations.weight[0].mask.count_nonzero()
                    == total - num_pruned
                )

    def _test_constructor_on_device(self, model, device):
        self.assertRaisesRegex(
            TypeError,
            "BaseStructuredSparsifier.*update_mask",
            BaseStructuredSparsifier,
        )
        model1 = copy.deepcopy(model).to(device)
        pruner = SimplePruner(None)
        pruner.prepare(model1, None)
        pruner.enable_mask_update = True
        for g in pruner.groups:
            module = g["module"]
            assert module.weight.device.type == device.type
        assert len(pruner.groups) == 5
        pruner.step()
        # Can instantiate the model with configs
        model2 = copy.deepcopy(model).to(device)
        pruner = SimplePruner({"test": 3})
        pruner.prepare(model2, [{"tensor_fqn": "seq.0.weight"}])
        assert len(pruner.groups) == 1
        assert pruner.groups[0]["module_fqn"] == "seq.0"
        assert "test" in pruner.groups[0]
        assert pruner.groups[0]["test"] == 3

    def test_constructor(self):
        model = SimpleLinear()
        for device in DEVICES:
            self._test_constructor_on_device(model, torch.device(device))

    def _test_prepare_linear_on_device(self, model, device):
        model = copy.deepcopy(model).to(device)
        x = torch.ones(128, 7, device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, None)
        self._check_pruner_prepared(model, pruner, device)
        assert model(x).shape == (128, 10)

    def test_prepare_linear(self):
        models = [
            SimpleLinear(),
            LinearBias(),
            LinearActivation(),
            LinearActivationFunctional(),
        ]  # without and with bias
        for device in DEVICES:
            for model in models:
                self._test_prepare_linear_on_device(model, torch.device(device))

    def _test_prepare_conv2d_on_device(self, model, expected_shape, config, device):
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        self._check_pruner_prepared(model, pruner, device)
        assert model(x).shape == expected_shape

    def test_prepare_conv2d(self):
        models = [
            SimpleConv2d(),
            Conv2dBias(),
            Conv2dActivation(),
            Conv2dPadBias(),
            Conv2dPool(),
        ]
        shapes = [
            (1, 52, 20, 20),
            (1, 52, 18, 18),
            (1, 52, 18, 18),
            (1, 52, 24, 24),
            (1, 52, 3, 3),
        ]
        configs = [None, None, None, None, None]
        for device in DEVICES:
            for model, shape, config in zip(models, shapes, configs):
                model = model.to(device)
                self._test_prepare_conv2d_on_device(
                    model, shape, config, torch.device(device)
                )

    def _test_step_linear_on_device(self, model, device):
        model = model.to(device)
        x = torch.ones(7, 7, device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, None)
        pruner.enable_mask_update = True
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
        self._check_pruner_valid_after_step(model, pruner, 1, device)

    def test_step_linear(self):
        models = [
            SimpleLinear(),
            LinearBias(),
            LinearActivation(),
            LinearActivationFunctional(),
        ]
        for device in DEVICES:
            for model in models:
                self._test_step_linear_on_device(model, torch.device(device))

    def _test_step_conv2d_on_device(self, model, expected_shape, config, device):
        model = model.to(device)
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        self._check_pruner_valid_before_step(model, pruner, device)
        pruner.step()
        self._check_pruner_valid_after_step(model, pruner, 1, device)
        assert model(x).shape == expected_shape

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_step_conv2d(self):
        models = [
            SimpleConv2d(),
            Conv2dBias(),
            Conv2dActivation(),
            Conv2dPadBias(),
            Conv2dPool(),
        ]
        shapes = [
            (1, 52, 20, 20),
            (1, 52, 18, 18),
            (1, 52, 18, 18),
            (1, 52, 24, 24),
            (1, 52, 3, 3),
        ]
        configs = [None, None, None, None, None]
        for device in DEVICES:
            for model, shape, config in zip(models, shapes, configs):
                self._test_step_conv2d_on_device(
                    model, shape, config, torch.device(device)
                )

    def _check_pruner_pruned(self, model, pruner, device):
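        """Verify that prune() removed the parametrizations and masks from each module."""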
        for config in pruner.groups:
            module = config["module"]
            assert not hasattr(module, "parametrizations")
            assert not hasattr(module, "mask")

    def _test_linear_on_device(
        self, model, config, expected_shape, device, also_prune_bias
    ):
        model = model.to(device)
        model.eval()
        num_original_params = sum(p.numel() for p in model.parameters())
        x = torch.ones(128, 7, device=device)

        pruner = ImplementedPruner({"prune_bias": also_prune_bias})
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()

        y_expected = model(x)

        assert y_expected.shape == (128, 10)
        self._check_pruner_prepared(model, pruner, device)

        # Pruning step
        pruned = pruner.prune()
        y_pruned = pruned(x)
        num_pruned_params = sum(p.numel() for p in pruned.parameters())

        assert y_pruned.shape == expected_shape
        self._check_pruner_pruned(model, pruner, device)
        if y_pruned.shape == y_expected.shape:
            assert torch.isclose(y_expected, y_pruned, rtol=1e-05, atol=1e-07).all()
            assert num_pruned_params < num_original_params

    def test_prune_linear_linear(self):
        r"""test pruning linear-> linear modules"""
        configs, shapes = [], []
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))

        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "linear1.weight"},
            ]
        )
        shapes.append((128, 10))

        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))
        for device in DEVICES:
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_linear_on_device(
                        SimpleLinear(),
                        config,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_linear_bias_linear(self):
        # linear(bias) -> linear(no bias)
        configs, shapes = [], []
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
            ]
        )
        shapes.append((128, 10))

        # linear(bias) -> linear(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "seq.3.weight"},
            ]
        )
        shapes.append((128, 10))

        # linear(no bias) -> linear(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((128, 10))

        for device in DEVICES:
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_linear_on_device(
                        LinearBias(),
                        config,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_linear_activation_linear(self):
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.2.weight"},
            {"tensor_fqn": "seq.4.weight"},
            {"tensor_fqn": "linear1.weight"},
        ]
        shape = (128, 10)

        for device in DEVICES:
            for also_prune_bias in [True, False]:
                # test version with nn.Modules
                self._test_linear_on_device(
                    LinearActivation(),
                    config,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )
                # test functional version
                self._test_linear_on_device(
                    LinearActivationFunctional(),
                    config,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )

    def _test_conv2d_on_device(
        self, model, config, x, expected_shape, device, also_prune_bias
    ):
        model = model.to(device)
        num_original_params = sum(p.numel() for p in model.parameters())
        model.eval()

        pruner = ImplementedPruner({"prune_bias": also_prune_bias})
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()

        y_expected = model(x)
        assert y_expected.shape == expected_shape

        self._check_pruner_prepared(model, pruner, device)

        # Pruning step
        pruned = pruner.prune()
        y_pruned = pruned(x)
        num_pruned_params = sum(p.numel() for p in pruned.parameters())

        assert y_pruned.shape == expected_shape
        self._check_pruner_pruned(model, pruner, device)
        if y_pruned.shape == y_expected.shape:
            # TODO: this rtol is a little high; double-check whether something specific is causing the mismatch
            assert torch.isclose(
                y_expected,
                y_pruned,
                rtol=1e-3,
                atol=1e-3,
            ).all(), f"fail for {type(model)}"
            # the parameter counts are only equal when every layer is padded and nothing can be pruned
            assert num_pruned_params <= num_original_params

    def test_prune_conv2d_conv2d(self):
        configs, shapes = [], []
        # all within sequential blocks
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
            ]
        )
        shapes.append((1, 52, 20, 20))
        # prune across sequential blocks
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 20, 20))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        SimpleConv2d(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_bias_conv2d(self):
        # Conv2d with Bias and no Activation
        configs, shapes = [], []
        # conv2d(bias) -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(no bias) -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.1.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dBias(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_activation_conv2d(self):
        # Conv2d with Activation and no Bias
        configs, shapes = [], []

        # conv2d(no bias) -> activation -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> activation -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(bias) -> activation -> conv2d(no bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        # conv2d(no bias) -> activation -> conv2d(bias)
        configs.append(
            [
                {"tensor_fqn": "conv2d1.weight"},
            ]
        )
        shapes.append((1, 52, 18, 18))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dActivation(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_padding_conv2d(self):
        # Conv2d with Padded layers after Bias layers
        configs, shapes = [], []

        # conv(padded, bias) -> conv(padded, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.4.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        # conv(no bias, no pad) -> conv(padded, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.2.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        # conv(padded, bias) -> conv(no bias, no pad)
        configs.append(
            [
                {"tensor_fqn": "seq.0.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))
        # conv(pad, bias) -> conv(no pad, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.6.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))
        # conv(no pad, bias) -> conv(pad, bias)
        configs.append(
            [
                {"tensor_fqn": "seq.8.weight"},
            ]
        )
        shapes.append((1, 52, 24, 24))

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                for config, shape in zip(configs, shapes):
                    self._test_conv2d_on_device(
                        Conv2dPadBias(),
                        config,
                        x,
                        shape,
                        torch.device(device),
                        also_prune_bias,
                    )

    def test_prune_conv2d_pool_conv2d(self):
        # Conv2d with Pooling layers
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.3.weight"},
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight"},
        ]
        shape = (1, 52, 3, 3)

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                self._test_conv2d_on_device(
                    Conv2dPool(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )

    @skipIfTorchDynamo("TorchDynamo fails with unknown reason")
    def test_complex_conv2d(self):
        """Test fusion for models that contain Conv2d & Linear modules.
        Currently supports: Conv2d-Pool2d-Flatten-Linear, Skip-add"""
        config = [
            {"tensor_fqn": "seq.0.weight"},
            {"tensor_fqn": "seq.3.weight"},
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight"},
        ]
        shape = (1, 13)

        for device in DEVICES:
            x = torch.ones((1, 1, 28, 28), device=device)
            for also_prune_bias in [True, False]:
                self._test_conv2d_on_device(
                    Conv2dPoolFlattenFunctional(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )
                self._test_conv2d_on_device(
                    Conv2dPoolFlatten(),
                    config,
                    x,
                    shape,
                    torch.device(device),
                    also_prune_bias,
                )

    def test_prune_lstm_linear_multiple_layer(self):
        """
        Test fusion support for LSTM(multi-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=2,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
            {"tensor_fqn": "lstm.weight_ih_l1"},
            {"tensor_fqn": "lstm.weight_hh_l1"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)

        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
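        # smoke test: make sure both the original and pruned models run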
        _, _ = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        _, _ = pruned_model(lstm_input)

        expected_params = dict(model.named_parameters())
        for name, param in pruned_model.named_parameters():
            assert name in expected_params
            # We cannot compare y_expected == y_pruned directly, as the zeroed
            # elements mess up the numerics. Instead we check that the weights
            # of the pruned LSTM are a subset of the weights of the original.
            assert rows_are_subset(param, expected_params[name])
            del expected_params[name]

        # assert that every original parameter was visited (no keys left over)
        assert len(expected_params) == 0

    def test_prune_lstm_linear_single_layer(self):
        """
        Test fusion support for LSTM (single-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=1,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()
        model.eval()

        out_expected, lstm_out_expected = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        out_pruned, lstm_out_pruned = pruned_model(lstm_input)
        r, c = lstm_out_expected.size()

        # We cannot check that y_expected == y_pruned as usual because
        # zeros vs. missing elements yield different numerical results.
        # Instead we check that the pruned output matches the first half of
        # the expected output, since we are using a BottomHalfLSTMPruner.
        assert torch.isclose(
            lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
        ).all()
        # also check that the output of the linear layer has the same shape;
        # this means we've resized the linear columns correctly
        assert out_expected.shape == out_pruned.shape

    def test_prune_lstm_layernorm_linear_multiple_layer(self):
        """
        Test fusion support for LSTM(multi-layer) -> LayerNorm -> Linear
        """
        model = LSTMLayerNormLinearModel(
            input_dim=8,
            output_dim=8,
            hidden_dim=8,
            num_layers=2,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
            {"tensor_fqn": "lstm.weight_ih_l1"},
            {"tensor_fqn": "lstm.weight_hh_l1"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)

        fx_pruner.enable_mask_update = True
        fx_pruner.step()

        model.eval()
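        # smoke test: make sure both the original and pruned models run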
        _, _ = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        _, _ = pruned_model(lstm_input)

        expected_params = dict(model.named_parameters())
        for name, param in pruned_model.named_parameters():
            assert name in expected_params
            # We cannot compare y_expected == y_pruned directly, as the zeroed
            # elements mess up the numerics. Instead we check that the weights
            # of the pruned LSTM are a subset of the weights of the original.
            assert rows_are_subset(param, expected_params[name])
            del expected_params[name]

        # assert that every original parameter was visited (no keys left over)
        assert len(expected_params) == 0

    def test_prune_lstm_layernorm_linear_single_layer(self):
        """
        Test fusion support for LSTM (single-layer) -> Linear
        """
        model = LSTMLinearModel(
            input_dim=8,
            hidden_dim=8,
            output_dim=8,
            num_layers=1,
        )

        config = [
            {"tensor_fqn": "lstm.weight_ih_l0"},
            {"tensor_fqn": "lstm.weight_hh_l0"},
        ]

        lstm_input = torch.ones((1, 8))
        fx_pruner = BottomHalfLSTMPruner({"sparsity_level": 0.5})
        fx_pruner.prepare(model, config)
        fx_pruner.enable_mask_update = True
        fx_pruner.step()
        model.eval()

        out_expected, lstm_out_expected = model(lstm_input)
        pruned_model = fx_pruner.prune()
        pruned_model.eval()
        out_pruned, lstm_out_pruned = pruned_model(lstm_input)
        r, c = lstm_out_expected.size()

        # We cannot check that y_expected == y_pruned as usual because
        # zeros vs. missing elements yield different numerical results.
        # Instead we check that the pruned output matches the first half of
        # the expected output, since we are using a BottomHalfLSTMPruner.
        assert torch.isclose(
            lstm_out_expected[:, : c // 2], lstm_out_pruned, rtol=1e-05, atol=1e-07
        ).all()
        # also check that the output of the linear layer has the same shape;
        # this means we've resized the linear columns correctly
        assert out_expected.shape == out_pruned.shape


class TestFPGMPruner(TestCase):
    """
    Test case for the implementation of paper:
    `Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration <https://arxiv.org/abs/1811.00250>`_.
    """

    class SimpleConvFPGM(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv2d1 = nn.Conv2d(
                in_channels=1, out_channels=3, kernel_size=3, padding=1, bias=False
            )
            # Manually set the filter weights for demonstration purposes.
            # The three filters are set to the values 3.0, 2.0, and 0.1.
            # A norm-based criterion would prune the 0.1 filter; FPGM instead
            # prunes the 2.0 filter, the one closest to the geometric median.
            weights = torch.tensor([3.0, 2.0, 0.1])  # weight value for each filter
            weights = weights[:, None, None, None]  # broadcasting
            self.conv2d1.weight.data.copy_(
                torch.ones(self.conv2d1.weight.shape) * weights
            )

            # Second Convolutional Layer
            self.conv2d2 = nn.Conv2d(
                in_channels=3, out_channels=4, kernel_size=3, padding=1, bias=False
            )
            weights = torch.tensor([6.0, 7.0, 0.4, 0.5])
            weights = weights[:, None, None, None]
            self.conv2d2.weight.data.copy_(
                torch.ones(self.conv2d2.weight.shape) * weights
            )

        def forward(self, x):
            x = self.conv2d1(x)
            x = self.conv2d2(x)
            return x

    def test_compute_distance(self, device="cpu"):
        """Test the distance computation function"""
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        pruner = FPGMPruner(0.3)
        dist_conv1 = pruner._compute_distance(model.conv2d1.weight)

        # compute the distance matrix using torch.cdist
        flattened_filters = torch.Tensor(
            [
                [3.0] * 9,  # filter 0: flattened 3x3 kernel
                [2.0] * 9,  # filter 1
                [0.1] * 9,  # filter 2
            ]
        )

        """
        Expected distance matrix should have the following values:
            [0.0000, 3.0000, 8.7000],
            [3.0000, 0.0000, 5.7000],
            [8.7000, 5.7000, 0.0000],
        the distance should therefore be:
            [11.7000, 8.7000, 14.4000]
        """
        expected_dist_matrix_conv1 = torch.cdist(
            flattened_filters, flattened_filters, p=2
        )
        expected_dist_conv1 = torch.sum(torch.abs(expected_dist_matrix_conv1), 1)
        assert torch.isclose(
            dist_conv1, expected_dist_conv1, rtol=1e-05, atol=1e-07
        ).all()

    def _test_update_mask_on_single_layer(self, expected_conv1, device):
        """Test that pruning is conducted based on the pair-wise distance measurement instead of absolute norm value"""
        # test pruning with one layer of conv2d
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        x = torch.ones((1, 1, 32, 32), device=device)
        pruner = FPGMPruner(0.3)
        config = [{"tensor_fqn": "conv2d1.weight"}]
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()
        assert (
            pruner.groups[0]["module"].parametrizations.weight[0].mask[-1].item()
            is not False
        ), "do not prune the least-norm filter"

        # fusion step
        pruned_model = pruner.prune()

        pruned_y = pruned_model(x)
        # assert shapes
        expected_conv1 = expected_conv1.to(device)
        assert pruned_y.shape == (1, 4, 32, 32)
        assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
        assert pruned_model.conv2d2.weight.shape == (
            4,
            2,
            3,
            3,
        ), "conv2d2 should have input channel pruned"
        # assert value
        assert torch.isclose(
            pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
        ).all()

    def _test_update_mask_on_multiple_layer(
        self, expected_conv1, expected_conv2, device
    ):
        # second setting: prune both conv layers, with conv2d2 at sparsity_level=0.5
        model = TestFPGMPruner.SimpleConvFPGM().to(device)
        x = torch.ones((1, 1, 32, 32), device=device)
        pruner = FPGMPruner(0.3)
        config = [
            {"tensor_fqn": "conv2d1.weight"},
            {"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5},
        ]
        pruner.prepare(model, config)
        pruner.enable_mask_update = True
        pruner.step()
        # Get the masks for the two least-norm filters
        mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1]
        mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2]
        # Check if either of the least-norm filters is not pruned
        assert (
            mask1.item() is not False or mask2.item() is not False
        ), "Do not prune all least-norm filters"

        # fusion step
        pruned_model = pruner.prune()
        pruned_y = pruned_model(x)
        # assert shapes
        expected_conv1 = expected_conv1.to(device)
        expected_conv2 = expected_conv2.to(device)
        assert pruned_y.shape == (1, 2, 32, 32)
        assert pruned_model.conv2d1.weight.shape == expected_conv1.shape
        assert pruned_model.conv2d2.weight.shape == expected_conv2.shape
        # assert values
        assert torch.isclose(
            pruned_model.conv2d1.weight, expected_conv1, rtol=1e-05, atol=1e-07
        ).all()
        assert torch.isclose(
            pruned_model.conv2d2.weight, expected_conv2, rtol=1e-05, atol=1e-07
        ).all()

    def test_update_mask(self):
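        # FPGM prunes conv2d1's 2.0 filter (the one closest to the geometric
        # median), so the 3.0 and 0.1 filters remain.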
        weights = torch.tensor([3.0, 0.1])
        expected_conv1 = torch.ones((2, 1, 3, 3)) * weights[:, None, None, None]

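        # at sparsity_level=0.5, conv2d2's 6.0 and 0.5 filters are pruned,
        # keeping 7.0 and 0.4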
        weights = torch.tensor([7.0, 0.4])
        expected_conv2 = torch.ones((2, 2, 3, 3)) * weights[:, None, None, None]

        for device in DEVICES:
            self._test_update_mask_on_single_layer(expected_conv1, device)
            self._test_update_mask_on_multiple_layer(
                expected_conv1, expected_conv2, device
            )
