# Owner(s): ["module: unknown"]

import inspect
import os
import warnings

from torch import nn
from torch.ao.pruning import BaseScheduler, CubicSL, LambdaSL, WeightNormSparsifier
from torch.testing._internal.common_utils import TestCase


class ImplementedScheduler(BaseScheduler):
    """Minimal concrete scheduler used by the tests below.

    At epoch 0 it reports the base levels captured at construction; on every
    later epoch it reports half of each group's current sparsity level.
    """

    def get_sl(self):
        # Past the first epoch: halve whatever level each group currently has.
        if self.last_epoch > 0:
            halved_levels = []
            for param_group in self.sparsifier.groups:
                halved_levels.append(param_group["sparsity_level"] * 0.5)
            return halved_levels
        # First epoch: hand back a fresh copy of the stored base levels.
        return list(self.base_sl)


class TestScheduler(TestCase):
    """Tests for the generic ``BaseScheduler`` machinery and ``LambdaSL``."""

    def test_constructor(self):
        """Constructing a scheduler attaches the sparsifier, performs the
        initial step and stores the base sparsity levels."""
        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        self.assertIs(scheduler.sparsifier, sparsifier)
        self.assertEqual(scheduler._step_count, 1)
        self.assertEqual(
            scheduler.base_sl, [sparsifier.groups[0]["sparsity_level"]]
        )

    def test_order_of_steps(self):
        """Checks if the warning is thrown if the scheduler step is called
        before the sparsifier step"""

        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        # Sparsifier step is not called
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Locate the base scheduler's source file from the class itself instead
        # of hard-coding a path: a literal path silently goes stale when the
        # package moves (e.g. torch.ao.sparsity -> torch.ao.pruning), turning
        # the check below into a no-op.
        base_scheduler_file = os.path.normcase(
            os.path.realpath(inspect.getfile(BaseScheduler))
        )

        # Correct order must not produce any warning from the base scheduler.
        # Warnings from other modules are recorded but deliberately ignored.
        with warnings.catch_warnings(record=True) as w:
            # Record every occurrence, even ones that the default filter would
            # suppress as duplicates of an earlier warning at the same line.
            warnings.simplefilter("always")
            sparsifier.step()
            scheduler.step()

        for warning in w:
            fname = os.path.normcase(os.path.realpath(warning.filename))
            self.assertNotEqual(
                fname,
                base_scheduler_file,
                msg=f"Unexpected base scheduler warning: {warning.message}",
            )

    def test_step(self):
        """A scheduler step after a sparsifier step applies ``get_sl``."""
        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 0.5)
        scheduler = ImplementedScheduler(sparsifier)
        # The initial (epoch 0) step keeps the base level.
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 0.5)

        sparsifier.step()
        scheduler.step()
        # ImplementedScheduler halves the level on every later epoch.
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 0.25)

    def test_lambda_scheduler(self):
        """``LambdaSL`` scales the base level by ``lambda(epoch)``."""
        model = nn.Sequential(nn.Linear(16, 16))
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 0.5)
        scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10)
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 0.0)  # Epoch 0
        scheduler.step()
        self.assertEqual(sparsifier.groups[0]["sparsity_level"], 5.0)  # Epoch 1


class TestCubicScheduler(TestCase):
    def setUp(self):
        self.model_sparse_config = [
            {"tensor_fqn": "0.weight", "sparsity_level": 0.8},
            {"tensor_fqn": "2.weight", "sparsity_level": 0.4},
        ]
        self.sorted_sparse_levels = [
            conf["sparsity_level"] for conf in self.model_sparse_config
        ]
        self.initial_sparsity = 0.1
        self.initial_step = 3

    def _make_model(self, **kwargs):
        model = nn.Sequential(
            nn.Linear(13, 17),
            nn.Dropout(0.5),
            nn.Linear(17, 3),
        )
        return model

    def _make_scheduler(self, model, **kwargs):
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=self.model_sparse_config)

        scheduler_args = {
            "init_sl": self.initial_sparsity,
            "init_t": self.initial_step,
        }
        scheduler_args.update(kwargs)

        scheduler = CubicSL(sparsifier, **scheduler_args)
        return sparsifier, scheduler

    @staticmethod
    def _get_sparsity_levels(sparsifier, precision=32):
        r"""Gets the current levels of sparsity in a sparsifier."""
        return [
            round(group["sparsity_level"], precision) for group in sparsifier.groups
        ]

    def test_constructor(self):
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True)
        self.assertIs(
            scheduler.sparsifier, sparsifier, msg="Sparsifier is not properly attached"
        )
        self.assertEqual(
            scheduler._step_count,
            1,
            msg="Scheduler is initialized with incorrect step count",
        )
        self.assertEqual(
            scheduler.base_sl,
            self.sorted_sparse_levels,
            msg="Scheduler did not store the target sparsity levels correctly",
        )

        # Value before t_0 is 0
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(0.0),
            msg="Sparsifier is not reset correctly after attaching to the Scheduler",
        )

        # Value before t_0 is s_0
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=False)
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(self.initial_sparsity),
            msg="Sparsifier is not reset correctly after attaching to the Scheduler",
        )

    def test_step(self):
        # For n=5, dt=2, there will be totally 10 steps between s_0 and s_f, starting from t_0
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(
            model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5
        )

        scheduler.step()
        scheduler.step()
        self.assertEqual(
            scheduler._step_count,
            3,
            msg="Scheduler step_count is expected to increment",
        )
        # Value before t_0 is supposed to be 0
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(0.0),
            msg="Scheduler step updating the sparsity level before t_0",
        )

        scheduler.step()  # Step = 3  =>  sparsity = initial_sparsity
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(self.initial_sparsity),
            msg="Sparsifier is not reset to initial sparsity at the first step",
        )

        scheduler.step()  # Step = 4  =>  sparsity ~ [0.3, 0.2]
        self.assertEqual(
            self._get_sparsity_levels(sparsifier, 1),
            [0.3, 0.2],
            msg="Sparsity level is not set correctly after the first step",
        )

        current_step = scheduler._step_count - scheduler.init_t[0] - 1
        more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step
        for _ in range(more_steps_needed):  # More steps needed to final sparsity level
            scheduler.step()
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            self.sorted_sparse_levels,
            msg="Sparsity level is not reaching the target level afer delta_t * n steps ",
        )
