# Owner(s): ["module: unknown"]


import logging

import torch
from torch.ao.pruning.sparsifier.utils import (
    fqn_to_module,
    get_arg_info_from_tensor_fqn,
    module_to_fqn,
)
from torch.testing._internal.common_quantization import (
    ConvBnReLUModel,
    ConvModel,
    FunctionalLinear,
    LinearAddModel,
    ManualEmbeddingBagLinear,
    SingleLayerLinearModel,
    TwoLayerLinearModel,
)
from torch.testing._internal.common_utils import TestCase


# Configure the root logger for this test module: records at INFO and above,
# each prefixed with a timestamp, the logger name and the level. Per the stdlib
# logging docs, basicConfig is a no-op if the root logger already has handlers.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

model_list = [
    ConvModel,
    SingleLayerLinearModel,
    TwoLayerLinearModel,
    LinearAddModel,
    ConvBnReLUModel,
    ManualEmbeddingBagLinear,
    FunctionalLinear,
]


class TestSparsityUtilFunctions(TestCase):
    def test_module_to_fqn(self):
        """
        Tests that module_to_fqn works as expected when compared to known good
        module.get_submodule(fqn) function
        """
        for model_class in model_list:
            model = model_class()
            list_of_modules = [m for _, m in model.named_modules()] + [model]
            for module in list_of_modules:
                fqn = module_to_fqn(model, module)
                check_module = model.get_submodule(fqn)
                self.assertEqual(module, check_module)

    def test_module_to_fqn_fail(self):
        """
        Tests that module_to_fqn returns None when an fqn that doesn't
        correspond to a path to a node/tensor is given
        """
        for model_class in model_list:
            model = model_class()
            fqn = module_to_fqn(model, torch.nn.Linear(3, 3))
            self.assertEqual(fqn, None)

    def test_module_to_fqn_root(self):
        """
        Tests that module_to_fqn returns '' when model and target module are the same
        """
        for model_class in model_list:
            model = model_class()
            fqn = module_to_fqn(model, model)
            self.assertEqual(fqn, "")

    def test_fqn_to_module(self):
        """
        Tests that fqn_to_module operates as inverse
        of module_to_fqn
        """
        for model_class in model_list:
            model = model_class()
            list_of_modules = [m for _, m in model.named_modules()] + [model]
            for module in list_of_modules:
                fqn = module_to_fqn(model, module)
                check_module = fqn_to_module(model, fqn)
                self.assertEqual(module, check_module)

    def test_fqn_to_module_fail(self):
        """
        Tests that fqn_to_module returns None when it tries to
        find an fqn of a module outside the model
        """
        for model_class in model_list:
            model = model_class()
            fqn = "foo.bar.baz"
            check_module = fqn_to_module(model, fqn)
            self.assertEqual(check_module, None)

    def test_fqn_to_module_for_tensors(self):
        """
        Tests that fqn_to_module works for tensors, actually all parameters
        of the model. This is tested by identifying a module with a tensor,
        and generating the tensor_fqn using module_to_fqn on the module +
        the name of the tensor.
        """
        for model_class in model_list:
            model = model_class()
            list_of_modules = [m for _, m in model.named_modules()] + [model]
            for module in list_of_modules:
                module_fqn = module_to_fqn(model, module)
                for tensor_name, tensor in module.named_parameters(recurse=False):
                    tensor_fqn = (  # string manip to handle tensors on root
                        module_fqn + ("." if module_fqn != "" else "") + tensor_name
                    )
                    check_tensor = fqn_to_module(model, tensor_fqn)
                    self.assertEqual(tensor, check_tensor)

    def test_get_arg_info_from_tensor_fqn(self):
        """
        Tests that get_arg_info_from_tensor_fqn works for all parameters of the model.
        Generates a tensor_fqn in the same way as test_fqn_to_module_for_tensors and
        then compares with known (parent) module and tensor_name as well as module_fqn
        from module_to_fqn.
        """
        for model_class in model_list:
            model = model_class()
            list_of_modules = [m for _, m in model.named_modules()] + [model]
            for module in list_of_modules:
                module_fqn = module_to_fqn(model, module)
                for tensor_name, tensor in module.named_parameters(recurse=False):
                    tensor_fqn = (
                        module_fqn + ("." if module_fqn != "" else "") + tensor_name
                    )
                    arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
                    self.assertEqual(arg_info["module"], module)
                    self.assertEqual(arg_info["module_fqn"], module_fqn)
                    self.assertEqual(arg_info["tensor_name"], tensor_name)
                    self.assertEqual(arg_info["tensor_fqn"], tensor_fqn)

    def test_get_arg_info_from_tensor_fqn_fail(self):
        """
        Tests that get_arg_info_from_tensor_fqn works as expected for invalid tensor_fqn
        inputs. The string outputs still work but the output module is expected to be None.
        """
        for model_class in model_list:
            model = model_class()
            tensor_fqn = "foo.bar.baz"
            arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn)
            self.assertEqual(arg_info["module"], None)
            self.assertEqual(arg_info["module_fqn"], "foo.bar")
            self.assertEqual(arg_info["tensor_name"], "baz")
            self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz")
