# Owner(s): ["module: unknown"]


import logging

import torch
import torch.ao.quantization as tq
from torch import nn
from torch.ao import pruning
from torch.ao.pruning import fqn_to_module
from torch.ao.quantization.quantize_fx import (
    convert_fx,
    convert_to_reference_fx,
    prepare_fx,
    prepare_qat_fx,
)
from torch.testing._internal.common_utils import TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)

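# Default sparsifier settings shared across these tests; the per-tensor sparse configs
# below may override individual values.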
sparse_defaults = {
    "sparsity_level": 0.8,
    "sparse_block_shape": (1, 4),
    "zeros_per_block": 4,
}


def _get_model_and_sparsifier_and_sparse_config(qconfig=None):
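    """Returns a small Linear/ReLU sequential model with Quant/DeQuant stubs, a
    WeightNormSparsifier, and a sparse config targeting the Linear weights at
    indices 0 and 5. If a qconfig is given, it is attached to the QuantStub (4)
    and the last Linear (5).
    """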
    model = nn.Sequential(
        nn.Linear(4, 4),  # 0
        nn.ReLU(),
        nn.Linear(4, 4),  # 2
        nn.ReLU(),
        tq.QuantStub(),
        nn.Linear(4, 4),  # 5
        nn.ReLU(),
        tq.DeQuantStub(),
    )
    if qconfig:
        model[4].qconfig = qconfig
        model[5].qconfig = qconfig

    sparsifier = pruning.WeightNormSparsifier(**sparse_defaults)

    sparse_config = [
        {
            "tensor_fqn": "5.weight",
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 4,
        },
        {"tensor_fqn": "0.weight"},
    ]
    return model, sparsifier, sparse_config


def _squash_mask_calibrate_and_convert(model, sparsifier, input):
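    """Applies one sparsifier step, squashes the sparsity masks into the weights,
    runs a single calibration forward pass, and converts the model in place."""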
    sparsifier.step()
    sparsifier.squash_mask()
    model(input)
    tq.convert(model, inplace=True)


def _calculate_sparsity(tensor):
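    """Returns the fraction of elements in ``tensor`` that are exactly zero."""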
    return ((tensor == 0).sum() / tensor.numel()).item()


# This series of tests checks the composability goals for sparsity and quantization, namely
# that performing quantization and sparsity model manipulations in various orderings
# does not cause problems.
class TestComposability(TestCase):
    # This test checks whether performing quantization prepare before sparse prepare
    # causes any issues and verifies that the correct observers are inserted and that
    # the quantized model works as expected
    def test_q_prep_before_s_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )

        tq.prepare(mod, inplace=True)
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))
        # check that correct observers were inserted
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # This test checks whether performing sparse prepare before quantization prepare
    # causes any issues. In particular, the previous quantization flow was unable to match
    # the post-sparse-prepare module names (adding parametrizations changes the module class names),
    # which would result in those parametrized modules not being quantized. This test verifies that
    # the fix for this was successful.
    def test_s_prep_before_q_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )

        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))

        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # If the sparsified modules have not undergone the final squash_mask operation, it's possible
    # that the problem outlined in test_s_prep_before_q_prep would occur. This test verifies
    # both that the fix to the convert flow avoids this issue and that the resulting quantized
    # module uses the sparse version of the weight value.
    def test_convert_without_squash_mask(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )

        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
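        # step the sparsifier, then calibrate and convert without calling squash_mask;
        # the convert flow should still pick up the sparsified weight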
        sparsifier.step()
        sparsity_level = _calculate_sparsity(mod[5].weight)
        mod(torch.randn(1, 4, 4, 4))
        tq.convert(mod, inplace=True)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing sparse prepare before fusion causes any issues. The
    # worry was that the link created between the sparsifier and the modules that need to
    # be sparsified would be broken.
    def test_s_prep_before_fusion(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )
        sparsifier.prepare(mod, config=sparse_config)
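        # fuse the Linear/ReLU pair at indices 5 and 6 into a single LinearReLU module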
        tq.fuse_modules(mod, [["5", "6"]], inplace=True)
        mod[5].qconfig = tq.get_default_qconfig("fbgemm")
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare or fusion
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5][0], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

    # This tests whether performing fusion before sparse prepare causes any issues. The
    # main worry was that the links to the modules in the sparse config would be broken by fusion.
    def test_fusion_before_s_prep(self):
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qconfig("fbgemm")
        )
        tq.fuse_modules(mod, [["5", "6"]], inplace=True)

        # fusion invalidates the original tensor fqns, but sparsification still works
        # if the correct post-fusion fqn is used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.weight"},
        ]

        sparsifier.prepare(mod, config=sparse_config)
        mod[5].qconfig = tq.get_default_qconfig("fbgemm")
        tq.prepare(mod, inplace=True)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5][0], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(mod[5][0].weight)
        mod(torch.randn(1, 4, 4, 4))
        tq.convert(mod, inplace=True)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.intrinsic.quantized.LinearReLU))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing sparse prepare before qat prepare causes issues.
    # The primary worries were that qat_prep wouldn't recognize the parametrized
    # modules and that the convert step for qat would remove the parametrizations
    # from the modules.
    def test_s_prep_before_qat_prep(self):
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm")
        )
        sparsifier.prepare(mod, config=sparse_config)
        tq.prepare_qat(mod, inplace=True)
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))
        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))
        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    # This tests whether performing qat prepare before sparse prepare causes issues.
    def test_qat_prep_before_s_prep(self):
        mod, sparsifier, _ = _get_model_and_sparsifier_and_sparse_config(
            tq.get_default_qat_qconfig("fbgemm")
        )
        tq.prepare_qat(mod, inplace=True)

        # need to set up sparse_config on the new modules created by qat prepare
        sparse_config = [
            {
                "tensor_fqn": "5.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during qat prepare
        self.assertTrue(hasattr(mod[0], "parametrizations"))
        self.assertTrue(hasattr(mod[5], "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(hasattr(mod[5], "activation_post_process"))
        self.assertTrue(isinstance(mod[5], torch.ao.nn.qat.Linear))

        _squash_mask_calibrate_and_convert(mod, sparsifier, torch.randn(1, 4, 4, 4))

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(isinstance(mod[5], torch.ao.nn.quantized.Linear))
        self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(mod[5]._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])


def _module_has_activation_post_process(model, fqn_of_module):
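    """Returns True if the FX graph contains an observer (activation_post_process)
    node whose input is the module with the given fqn."""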
    for node in model.graph.nodes:
        # look for an observer whose arg is the target module
        if "activation_post_process" in node.name:
            if node.args[0].target == fqn_of_module:
                return True
    return False


class TestFxComposability(TestCase):
    r"""This series of tests checks that various steps of the quantization and sparsity flow
    compose cleanly despite variation in sequencing.
    """

    def test_q_prep_fx_before_s_prep(self):
        r"""
        This test checks that the ordering prepare_fx -> sparse prepare -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask between sparse prepare and convert_fx. This also tests the
        automatic fusion that occurs during prepare_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
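        # quantize only the QuantStub at "4" and the Linear at "5"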
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )

        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # auto fusion in fx invalidates the original tensor fqns,
        # but sparsification still works if the correct post-fusion fqn is used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
            )
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_q_prep_fx_s_prep_ref_conv(self):
        r"""
        This checks that the ordering prepare_fx -> sparse prepare -> convert_to_reference_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask before convert_to_reference_fx.
        """
        (
            mod,
            sparsifier,
            _,
        ) = _get_model_and_sparsifier_and_sparse_config()

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )

        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # auto fusion in fx invalidates the original tensor fqns,
        # but sparsification still works if the correct post-fusion fqn is used
        sparse_config = [
            {
                "tensor_fqn": "5.0.weight",
                "sparsity_level": 0.7,
                "sparse_block_shape": (1, 4),
                "zeros_per_block": 4,
            },
            {"tensor_fqn": "0.0.weight"},
        ]
        sparsifier.prepare(mod, config=sparse_config)

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_to_reference_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear
            )
        )

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_before_q_prep_fx(self):
        r"""
        This test checks that the ordering sparse prepare -> prepare_fx -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
            )
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_before_qat_prep_fx(self):
        r"""
        This test checks that the ordering sparse prepare -> prepare_qat_fx -> convert_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask before convert_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qat_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_qat_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5"), "parametrizations"))
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.qat.LinearReLU)
        )

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.weight"))
        mod(example)
        mod = convert_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.quantized.LinearReLU
            )
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5")._weight_bias()[0])
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])

    def test_s_prep_q_prep_fx_ref(self):
        r"""
        This checks that the ordering sparse prepare -> prepare_fx -> convert_to_reference_fx
        composes cleanly without issue and that the final result is sparsified without
        having to call squash_mask before convert_to_reference_fx.
        """
        (
            mod,
            sparsifier,
            sparse_config,
        ) = _get_model_and_sparsifier_and_sparse_config()
        sparsifier.prepare(mod, config=sparse_config)

        example = torch.randn(1, 4, 4, 4)
        qconfig = tq.get_default_qconfig("fbgemm")
        qconfig_mapping = (
            tq.QConfigMapping()
            .set_module_name("4", qconfig)
            .set_module_name("5", qconfig)
        )
        mod = prepare_fx(mod, qconfig_mapping, (example,))

        # check that correct modules had parametrizations added and
        # that none were lost during prepare
        self.assertTrue(hasattr(fqn_to_module(mod, "0.0"), "parametrizations"))
        self.assertTrue(hasattr(fqn_to_module(mod, "5.0"), "parametrizations"))

        # check that correct observers were inserted and that matching
        # occurred successfully
        self.assertTrue(_module_has_activation_post_process(mod, "5"))
        sparsifier.step()
        sparsity_level = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        mod(example)
        mod = convert_to_reference_fx(mod)

        # check that final module is the expected quantized module and that the model runs
        self.assertTrue(
            isinstance(fqn_to_module(mod, "5"), torch.ao.nn.intrinsic.LinearReLU)
        )
        self.assertEqual(mod(example).shape, torch.Size([1, 4, 4, 4]))
        self.assertTrue(
            isinstance(
                fqn_to_module(mod, "5.0"), torch.ao.nn.quantized.reference.Linear
            )
        )

        # check that module was actually sparsified
        cur_sparsity = _calculate_sparsity(fqn_to_module(mod, "5.0.weight"))
        self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
        self.assertGreaterAlmostEqual(
            sparsity_level, sparse_config[0]["sparsity_level"]
        )
        self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
