# Owner(s): ["module: unknown"]

import copy
import logging
import warnings
from typing import Tuple

import torch
from torch import nn
from torch.ao.pruning._experimental.data_scheduler import BaseDataScheduler
from torch.ao.pruning._experimental.data_sparsifier import DataNormSparsifier
from torch.testing._internal.common_utils import TestCase


logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class ImplementedDataScheduler(BaseDataScheduler):
    def __init__(self, sparsifier, sparsifier_hyperparam, last_epoch=-1, verbose=False):
        super().__init__(sparsifier, sparsifier_hyperparam, last_epoch, verbose)

    def get_schedule_param(self):
        if self.last_epoch > 0:
            return {
                name: config["sparsity_level"] * 0.5
                for name, config in self.data_sparsifier.data_groups.items()
            }
        else:
            return self.base_param


class TestBaseDataScheduler(TestCase):
    def _get_data(self):
        tensor1, param1, emb1 = (
            torch.randn(5, 5),
            nn.Parameter(torch.randn(10, 10)),
            nn.Embedding(50, 5),
        )
        data_list = [("tensor1", tensor1), ("param1", param1), ("emb1", emb1)]
        defaults = {
            "sparsity_level": 0.7,
            "sparse_block_shape": (1, 4),
            "zeros_per_block": 2,
        }
        data_with_config = [
            {
                "name": "tensor2",
                "data": torch.randn(4, 4),
                "config": {"sparsity_level": 0.3},
            }
        ]
        return data_list, data_with_config, defaults

    def _get_sparsifier(self, data_list, data_with_config, defaults):
        sparsifier = DataNormSparsifier(data_list, **defaults)
        for data_config_dict in data_with_config:
            name, data, config = (
                data_config_dict["name"],
                data_config_dict["data"],
                data_config_dict["config"],
            )
            sparsifier.add_data(name=name, data=data, **config)
        return sparsifier

    def _get_scheduler(self, sparsifier, schedule_param):
        scheduler = ImplementedDataScheduler(sparsifier, schedule_param)
        return scheduler

    def _get_schedule_param(self):
        return "sparsity_level"

    def _get_name_data_config(self, some_data, defaults):
        config = copy.deepcopy(defaults)
        if isinstance(some_data, Tuple):
            # dealing with data_list
            name, data = some_data
        else:
            # dealing with data_with_config
            name, data, new_config = (
                some_data["name"],
                some_data["data"],
                some_data["config"],
            )
            config.update(new_config)
        return name, data, config

    def test_constructor(self):
        """Checks if the warning is thrown if the scheduler step is called
        before the sparsifier step"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        assert scheduler.data_sparsifier == sparsifier
        assert scheduler._step_count == 1

        for name, config in sparsifier.data_groups.items():
            assert scheduler.base_param[name] == config.get(schedule_param, None)

    def test_order_of_steps(self):
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        # Sparsifier step is not called
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Correct order has no warnings
        # Note: This will trigger if other warnings are present.
        with warnings.catch_warnings(record=True) as w:
            sparsifier.step()
            scheduler.step()
            # Make sure there is no warning related to the base_data_scheduler
            for warning in w:
                fname = warning.filename
                fname = "/".join(fname.split("/")[-5:])
                assert (
                    fname
                    != "torch/ao/sparsity/experimental/scheduler/data_scheduler/base_data_scheduler.py"
                )

    def test_step(self):
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        all_data = data_list + data_with_config

        for some_data in all_data:
            name, _, config = self._get_name_data_config(some_data, defaults)
            assert (
                sparsifier.data_groups[name][schedule_param] == config[schedule_param]
            )

        sparsifier.step()
        scheduler.step()

        for some_data in all_data:
            name, _, config = self._get_name_data_config(some_data, defaults)
            assert (
                sparsifier.data_groups[name][schedule_param]
                == config[schedule_param] * 0.5
            )

        # checking step count
        step_cnt = 5
        for _ in range(0, step_cnt):
            sparsifier.step()
            scheduler.step()

        assert (
            scheduler._step_count == step_cnt + 2
        )  # step_cnt + step above + 1 step in constructor

    def test_state_dict(self):
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler1 = self._get_scheduler(sparsifier, schedule_param)

        sparsifier.step()
        scheduler1.step()

        scheduler2 = self._get_scheduler(sparsifier, schedule_param)
        all_data = data_list + data_with_config
        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data, defaults)
            assert scheduler1.base_param[name] != scheduler2.base_param[name]
            assert scheduler1._last_param[name] == scheduler2.base_param[name]

        scheduler1_state = scheduler1.state_dict()
        scheduler2.load_state_dict(scheduler1_state)

        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data, defaults)
            assert scheduler1.base_param[name] == scheduler2.base_param[name]
            assert scheduler1._last_param[name] == scheduler2._last_param[name]
