# Owner(s): ["module: optimizer", "module: LrScheduler" ]
import copy
import math
import pickle
import tempfile
import types
import warnings
from functools import partial

import torch
import torch.nn.functional as F
from torch.nn import Parameter
from torch.optim import Adam, Rprop, SGD
from torch.optim.lr_scheduler import (
    ChainedScheduler,
    ConstantLR,
    CosineAnnealingLR,
    CosineAnnealingWarmRestarts,
    CyclicLR,
    EPOCH_DEPRECATION_WARNING,
    ExponentialLR,
    LambdaLR,
    LinearLR,
    LRScheduler,
    MultiplicativeLR,
    MultiStepLR,
    OneCycleLR,
    PolynomialLR,
    ReduceLROnPlateau,
    SequentialLR,
    StepLR,
)
from torch.optim.swa_utils import SWALR
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    load_tests,
    parametrize,
    skipIfTorchDynamo,
    TestCase,
)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests


class TestLRScheduler(TestCase):
    class SchedulerTestNet(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv1 = torch.nn.Conv2d(1, 1, 1)
            self.conv2 = torch.nn.Conv2d(1, 1, 1)

        def forward(self, x):
            return self.conv2(F.relu(self.conv1(x)))

    class LambdaLRTestObject:
        def __init__(self, value):
            self.value = value

        def __call__(self, epoch):
            return self.value * epoch

        def __eq__(self, other):
            if isinstance(other, self.__class__):
                return self.__dict__ == other.__dict__
            else:
                return False

    exact_dtype = True

    def setUp(self):
        super().setUp()
        self.net = self.SchedulerTestNet()
        self.opt = SGD(
            [
                {"params": self.net.conv1.parameters()},
                {"params": self.net.conv2.parameters(), "lr": 0.5},
            ],
            lr=0.05,
        )

    def _check_warning_is_epoch_deprecation_warning(self, w, *, num_warnings: int = 1):
        """This function swallows the epoch deprecation warning which is produced when we
        call `scheduler.step(epoch)` with some not `None` value of `epoch`.
        this is deprecated, and this function will need to be removed/updated when
        the schedulers no longer accept the parameter at all.
        """
        self.assertEqual(len(w), num_warnings)
        for warning in w:
            self.assertEqual(len(warning.message.args), 1)
            self.assertEqual(warning.message.args[0], EPOCH_DEPRECATION_WARNING)

    def test_error_when_getlr_has_epoch(self):
        class MultiStepLR(torch.optim.lr_scheduler.LRScheduler):
            def __init__(self, optimizer, gamma, milestones, last_epoch=-1):
                self.init_lr = [group["lr"] for group in optimizer.param_groups]
                self.gamma = gamma
                self.milestones = milestones
                super().__init__(optimizer, last_epoch)

            def get_lr(self, step):
                global_step = self.last_epoch
                gamma_power = (
                    [0]
                    + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m]
                )[-1]
                return [
                    init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr
                ]

        optimizer = SGD([torch.rand(1)], lr=1)

        with self.assertRaises(TypeError):
            scheduler = MultiStepLR(optimizer, gamma=1, milestones=[10, 20])

    @skipIfTorchDynamo(
        "Torchdynamo keeps references to optim in the guards and the stack of the graph break frames"
    )
    def test_no_cyclic_references(self):
        import gc

        param = Parameter(torch.empty(10))
        optim = SGD([param], lr=0.5)
        scheduler = LambdaLR(optim, lambda epoch: 1.0)
        del scheduler

        self.assertTrue(
            len(gc.get_referrers(optim)) == 0,
            "Optimizer should contain no cyclic references",
        )

        gc.collect()
        del optim
        self.assertEqual(
            gc.collect(), 0, msg="Optimizer should be garbage-collected on __del__"
        )

    @skipIfTorchDynamo(
        "Torchdynamo keeps references to optim in the guards and the stack of the graph break frames"
    )
    def test_no_cyclic_references_in_step(self):
        import gc
        import weakref

        def run():
            param = torch.empty(10, requires_grad=True)
            optim = SGD(params=[param], lr=0.5)
            scheduler = LambdaLR(optim, lambda epoch: 1.0)
            param.sum().backward()
            optim.step()
            scheduler.step()

            return weakref.ref(scheduler)

        # To verify that the scheduler holds no reference cycles, turn off the garbage
        # collector; otherwise gc would automatically collect unreachable (cyclic)
        # objects and mask the problem.
        gc.disable()
        ref = run()

        assert ref() is None
        gc.enable()  # restore

    def test_old_pattern_warning(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern():
            for _ in range(epochs):
                scheduler.step()
                self.opt.step()

        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern)

    def test_old_pattern_warning_with_arg(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern2():
            for _ in range(epochs):
                scheduler.step()
                self.opt.step()

        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)

    def test_old_pattern_warning_resuming(self):
        epochs = 35
        for i, group in enumerate(self.opt.param_groups):
            group["initial_lr"] = 0.01

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern():
            for _ in range(epochs):
                scheduler.step()
                self.opt.step()

        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern)

    def test_old_pattern_warning_resuming_with_arg(self):
        epochs = 35
        for i, group in enumerate(self.opt.param_groups):
            group["initial_lr"] = 0.01

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        def old_pattern2():
            for _ in range(epochs):
                scheduler.step()
                self.opt.step()

        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)

    def test_old_pattern_warning_with_overridden_optim_step(self):
        epochs = 35
        for i, group in enumerate(self.opt.param_groups):
            group["initial_lr"] = 0.01

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3, last_epoch=10)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        # emulate use-case with optimizer.step overridden

        old_step = self.opt.step

        def new_step(o, *args, **kwargs):
            retval = old_step(*args, **kwargs)
            return retval

        self.opt.step = types.MethodType(new_step, self.opt)

        def old_pattern2():
            for _ in range(epochs):
                scheduler.step()
                self.opt.step()

        self.assertWarnsRegex(UserWarning, r"how-to-adjust-learning-rate", old_pattern2)

    def test_new_pattern_no_warning(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            for _ in range(epochs):
                self.opt.step()
                scheduler.step()
            self.assertTrue(len(ws) == 0, "No warning should be raised")

    def test_new_pattern_no_warning_with_arg(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            for _ in range(epochs):
                self.opt.step()
                scheduler.step()
            self.assertTrue(len(ws) == 0, "No warning should be raised")

    def test_new_pattern_no_warning_with_overridden_optim_step(self):
        epochs = 35
        with warnings.catch_warnings(record=True) as ws:
            warnings.simplefilter("always")  # allow any warning to be raised
            scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
            self.assertTrue(len(ws) == 0, "No warning should be raised")

        # emulate use-case with optimizer.step overridden

        old_step = self.opt.step

        def new_step(o, *args, **kwargs):
            retval = old_step(*args, **kwargs)
            return retval

        self.opt.step = types.MethodType(new_step, self.opt)

        def new_pattern():
            for e in range(epochs):
                self.opt.step()
                scheduler.step()

        self.assertWarnsRegex(
            UserWarning, r"`optimizer.step\(\)` has been overridden", new_pattern
        )

    def _test_lr_is_constant_for_constant_epoch(self, scheduler):
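        # Stepping repeatedly with the same epoch value should leave the LR unchanged.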
        lrs = []

        for _ in range(10):
            scheduler.optimizer.step()
            with warnings.catch_warnings(record=True) as w:
                scheduler.step(2)
                self._check_warning_is_epoch_deprecation_warning(w)

            lrs.append(self.opt.param_groups[0]["lr"])
        self.assertEqual(min(lrs), max(lrs))

    def test_step_lr_is_constant_for_constant_epoch(self):
        scheduler = StepLR(self.opt, 2)
        self._test_lr_is_constant_for_constant_epoch(scheduler)

    def test_exponential_lr_is_constant_for_constant_epoch(self):
        scheduler = ExponentialLR(self.opt, gamma=0.9)
        self._test_lr_is_constant_for_constant_epoch(scheduler)

    def test_constantlr_is_constant_for_constant_epoch(self):
        scheduler = ConstantLR(self.opt)
        self._test_lr_is_constant_for_constant_epoch(scheduler)

    def test_linear_linearlr_is_constant_for_constant_epoch(self):
        scheduler = LinearLR(self.opt)
        self._test_lr_is_constant_for_constant_epoch(scheduler)

    def test_polynomial_lr_is_constant_for_constant_epoch(self):
        scheduler = PolynomialLR(self.opt, power=0.9)
        self._test_lr_is_constant_for_constant_epoch(scheduler)

    def test_step_lr(self):
        # lr = 0.05     if epoch < 3
        # lr = 0.005    if 3 <= epoch < 6
        # lr = 0.0005   if 6 <= epoch < 9
        # lr = 0.00005  if epoch >= 9
        epochs = 10
        single_targets = [0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
        self._test(scheduler, targets, epochs)

    def test_get_last_lr_step_lr(self):
        from torch.nn import Parameter

        epochs = 10
        optimizer = SGD([Parameter(torch.randn(2, 2, requires_grad=True))], 0.1)
        targets = [[0.1] * 3 + [0.01] * 3 + [0.001] * 3 + [0.0001]]
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 3, gamma=0.1)
        self._test_get_last_lr(scheduler, targets, epochs)

    def test_get_last_lr_multi_step_lr(self):
        # lr = 0.05     if epoch < 2
        # lr = 0.005    if 2 <= epoch < 5
        # lr = 0.0005   if 5 <= epoch < 9
        # lr = 0.00005   if 9 <= epoch
        epochs = 10
        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 1
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        self._test_get_last_lr(scheduler, targets, epochs)

    def test_multi_step_lr(self):
        # lr = 0.05     if epoch < 2
        # lr = 0.005    if 2 <= epoch < 5
        # lr = 0.0005   if 5 <= epoch < 9
        # lr = 0.00005  if epoch >= 9
        epochs = 10
        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        self._test(scheduler, targets, epochs)

    def test_multi_step_lr_with_epoch(self):
        # lr = 0.05     if epoch < 2
        # lr = 0.005    if 2 <= epoch < 5
        # lr = 0.0005   if 5 <= epoch < 9
        # lr = 0.00005  if epoch >= 9
        epochs = 10
        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 3
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        self._test_with_epoch(scheduler, targets, epochs)

    def test_get_last_lr_constantlr(self):
        # lr = 0.025    if epoch < 5
        # lr = 0.05     if 5 <= epoch
        epochs = 10
        single_targets = [0.025] * 5 + [0.05] * 5
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
        self._test_get_last_lr(scheduler, targets, epochs)

    def test_get_last_lr_linearlr(self):
        # lr = 0.0125      if epoch == 0
        # lr = 0.016875    if epoch == 1
        # lr = 0.02125     if epoch == 2
        # lr = 0.025625    if epoch == 3
        # lr = 0.03        if 4 <= epoch
        epochs = 10
        start_factor = 1.0 / 4
        end_factor = 3.0 / 5
        iters = 4
        interpolation = [
            start_factor + i * (end_factor - start_factor) / iters for i in range(iters)
        ]
        single_targets = [x * 0.05 for x in interpolation] + [0.05 * end_factor] * (
            epochs - iters
        )
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = LinearLR(
            self.opt,
            start_factor=start_factor,
            end_factor=end_factor,
            total_iters=iters,
        )
        self._test_get_last_lr(scheduler, targets, epochs)

    def test_constantlr(self):
        # lr = 0.025    if epoch < 5
        # lr = 0.05     if 5 <= epoch
        epochs = 10
        single_targets = [0.025] * 5 + [0.05] * 5
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
        self._test(scheduler, targets, epochs)

    def test_linearlr(self):
        # lr = 0.025     if epoch == 0
        # lr = 0.03125   if epoch == 1
        # lr = 0.0375    if epoch == 2
        # lr = 0.04375   if epoch == 3
        # lr = 0.05      if 4 <= epoch
        epochs = 10
        start_factor = 1.0 / 2
        iters = 4
        interpolation = [
            start_factor + i * (1 - start_factor) / iters for i in range(iters)
        ]
        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
        self._test(scheduler, targets, epochs)

    def test_linearlr_start_factor_limits1(self):
        start_factor = 0.0
        iters = 4
        with self.assertRaises(ValueError):
            LinearLR(self.opt, start_factor=start_factor, total_iters=iters)

    def test_linearlr_start_factor_limits2(self):
        start_factor = 1.1
        iters = 4
        with self.assertRaises(ValueError):
            LinearLR(self.opt, start_factor=start_factor, total_iters=iters)

    def test_constantlr_with_epoch(self):
        # lr = 0.025    if epoch < 5
        # lr = 0.05     if 5 <= epoch
        epochs = 10
        single_targets = [0.025] * 5 + [0.05] * 5
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = ConstantLR(self.opt, factor=1.0 / 2, total_iters=5)
        self._test_with_epoch(scheduler, targets, epochs)

    def test_linearlr_with_epoch(self):
        # lr = 0.025     if epoch == 0
        # lr = 0.03125   if epoch == 1
        # lr = 0.0375    if epoch == 2
        # lr = 0.04375   if epoch == 3
        # lr = 0.05      if 4 <= epoch
        epochs = 10
        start_factor = 1.0 / 2
        end_factor = 1.0
        iters = 4
        interpolation = [
            start_factor + i * (end_factor - start_factor) / iters for i in range(iters)
        ]
        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (epochs - iters)
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
        self._test_with_epoch(scheduler, targets, epochs)

    def test_exp_lr(self):
        epochs = 10
        single_targets = [0.05 * (0.9**x) for x in range(epochs)]
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = ExponentialLR(self.opt, gamma=0.9)
        self._test(scheduler, targets, epochs)

    def test_poly_lr(self):
        epochs = 10
        power = 0.9
        total_iters = 5
        single_targets = [
            (1.0 - x / total_iters) ** power * 0.05 for x in range(total_iters)
        ] + [0.0] * (epochs - total_iters)
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = PolynomialLR(self.opt, power=power, total_iters=total_iters)
        self._test(scheduler, targets, epochs)

    def test_cos_anneal_lr(self):
        epochs = 10
        eta_min = 1e-10
        single_targets = [
            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
            for x in range(epochs)
        ]
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
        self._test(scheduler, targets, epochs)

    def test_closed_form_step_lr(self):
        scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
        closed_form_scheduler = StepLR(self.opt, gamma=0.1, step_size=3)
        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)

    def test_closed_form_linearlr(self):
        scheduler = LinearLR(
            self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4
        )
        closed_form_scheduler = LinearLR(
            self.opt, start_factor=1.0 / 3, end_factor=0.7, total_iters=4
        )
        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)

    def test_closed_form_constantlr(self):
        scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4)
        closed_form_scheduler = ConstantLR(self.opt, factor=1.0 / 3, total_iters=4)
        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)

    def test_closed_form_multi_step_lr(self):
        scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        closed_form_scheduler = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)

    def test_closed_form_exp_lr(self):
        scheduler = ExponentialLR(self.opt, gamma=0.9)
        closed_form_scheduler = ExponentialLR(self.opt, gamma=0.9)
        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)

    def test_closed_form_poly_lr(self):
        scheduler = PolynomialLR(self.opt, power=0.9)
        closed_form_scheduler = PolynomialLR(self.opt, power=0.9)
        self._test_against_closed_form(scheduler, closed_form_scheduler, 20)

    def test_closed_form_cos_anneal_lr(self):
        eta_min = 1e-10
        epochs = 20
        T_max = 5
        scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min)
        closed_form_scheduler = CosineAnnealingLR(
            self.opt, T_max=T_max, eta_min=eta_min
        )
        self._test_against_closed_form(scheduler, closed_form_scheduler, epochs)

    def test_cos_anneal_lr_continue(self):
        eta_min = 0.1
        T_max = 5
        scheduler = CosineAnnealingLR(self.opt, T_max=T_max, eta_min=eta_min)
        self.opt.step()
        scheduler.step()
        original_lrs = scheduler._last_lr
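        # Constructing a new scheduler with last_epoch=0 should resume at the LR that
        # the original scheduler reached after one step.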
        new_scheduler = CosineAnnealingLR(
            self.opt, T_max=T_max, eta_min=eta_min, last_epoch=0
        )
        new_lrs = new_scheduler._last_lr
        torch.testing.assert_close(original_lrs, new_lrs, rtol=1e-4, atol=1e-5)

    def test_reduce_lr_on_plateau1(self):
        epochs = 10
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        targets = [[0.5] * 20]
        metrics = [10 - i * 0.0167 for i in range(20)]
        scheduler = ReduceLROnPlateau(
            self.opt,
            threshold_mode="abs",
            mode="min",
            threshold=0.01,
            patience=5,
            cooldown=5,
        )
        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

    def test_reduce_lr_on_plateau2(self):
        epochs = 22
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        targets = [[0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2]
        metrics = [10 - i * 0.0165 for i in range(22)]
        scheduler = ReduceLROnPlateau(
            self.opt,
            patience=5,
            cooldown=0,
            threshold_mode="abs",
            mode="min",
            threshold=0.1,
        )
        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

    def test_reduce_lr_on_plateau3(self):
        epochs = 22
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        targets = [[0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4]
        metrics = [-0.8] * 2 + [-0.234] * 20
        scheduler = ReduceLROnPlateau(
            self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs"
        )
        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

    def test_reduce_lr_on_plateau4(self):
        epochs = 20
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        targets = [[0.5] * 20]
        metrics = [1.5 * (1.025**i) for i in range(20)]  # 1.025 > 1.1**0.25
        scheduler = ReduceLROnPlateau(
            self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1
        )
        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

    def test_reduce_lr_on_plateau5(self):
        epochs = 20
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
        metrics = [1.5 * (1.005**i) for i in range(20)]
        scheduler = ReduceLROnPlateau(
            self.opt,
            mode="max",
            threshold_mode="rel",
            threshold=0.1,
            patience=5,
            cooldown=5,
        )
        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

    def test_reduce_lr_on_plateau6(self):
        epochs = 20
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        targets = [[0.5] * 20]
        metrics = [1.5 * (0.85**i) for i in range(20)]
        scheduler = ReduceLROnPlateau(
            self.opt, mode="min", threshold_mode="rel", threshold=0.1
        )
        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

    def test_reduce_lr_on_plateau7(self):
        epochs = 20
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4]
        metrics = [1] * 7 + [0.6] + [0.5] * 12
        scheduler = ReduceLROnPlateau(
            self.opt,
            mode="min",
            threshold_mode="rel",
            threshold=0.1,
            patience=5,
            cooldown=5,
        )
        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

    def test_reduce_lr_on_plateau8(self):
        epochs = 20
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        targets = [[0.5] * 6 + [0.4] * 14, [0.5] * 6 + [0.3] * 14]
        metrics = [1.5 * (1.005**i) for i in range(20)]
        scheduler = ReduceLROnPlateau(
            self.opt,
            mode="max",
            threshold_mode="rel",
            min_lr=[0.4, 0.3],
            threshold=0.1,
            patience=5,
            cooldown=5,
        )
        self._test_reduce_lr_on_plateau(scheduler, targets, metrics, epochs)

    def test_reduce_lr_on_plateau_get_last_lr_before_step(self):
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        scheduler = ReduceLROnPlateau(
            self.opt,
        )
        self.assertEqual(
            scheduler.get_last_lr(), [0.5 for param_group in self.opt.param_groups]
        )

    def test_sequentiallr1(self):
        epochs = 19
        schedulers = [None] * 2
        targets = [
            [0.05, 0.04, 0.032]
            + [0.05 for x in range(4)]
            + [0.05 * 0.1 for x in range(4)]
            + [0.05 * 0.01 for x in range(4)]
            + [0.05 * 0.001 for x in range(4)]
        ]
        milestones = [3]
        schedulers[0] = ExponentialLR(self.opt, gamma=0.8)
        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=4)
        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
        self._test(scheduler, targets, epochs)

    def test_sequentiallr2(self):
        epochs = 13
        schedulers = [None] * 2
        targets = [[0.005, 0.005, 0.005] + [0.05 * 0.9**x for x in range(10)]]
        milestones = [3]
        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
        self._test(scheduler, targets, epochs)

    def test_sequentiallr3(self):
        epochs = 12
        schedulers = [None] * 3
        targets = [
            [0.005, 0.005, 0.005]
            + [0.05, 0.04, 0.032]
            + [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
        ]
        milestones = [3, 6]
        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
        schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
        self._test(scheduler, targets, epochs)

    def test_sequentiallr4(self):
        optimizer = SGD([torch.tensor(0.5)], lr=0.1)
        prev_lr = optimizer.param_groups[0]["lr"]

        schedulers = [
            torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1),
            torch.optim.lr_scheduler.ConstantLR(optimizer, factor=0.1),
        ]
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers, milestones=[10]
        )

        new_lr = optimizer.param_groups[0]["lr"]

        # Ensure that constructing multiple schedulers does not affect the initial learning rate
        self.assertEqual(prev_lr, new_lr)

    def test_get_last_lr_sequentiallr(self):
        epochs = 12
        milestones = [3, 6]
        schedulers = [None] * 3
        schedulers[0] = ConstantLR(self.opt, factor=0.1, total_iters=3)
        schedulers[1] = ExponentialLR(self.opt, gamma=0.8)
        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=2)
        scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones)
        constant_lr_target = [0.005] * 3
        exponential_lr_target = [0.05, 0.04, 0.032]
        step_lr_target = [0.05, 0.05, 0.005, 0.005, 0.0005, 0.0005]
        single_targets = constant_lr_target + exponential_lr_target + step_lr_target
        targets = [single_targets, [x * 10 for x in single_targets]]
        self._test_get_last_lr(scheduler, targets, epochs)

    def test_chained_lr2_get_last_lr_before_step(self):
        schedulers = [
            LinearLR(self.opt, start_factor=0.4, total_iters=3),
            MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1),
        ]
        scheduler = ChainedScheduler(schedulers)
        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())

    def test_chained_lr1(self):
        epochs = 10
        schedulers = [None] * 1
        targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005] * 3]
        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
        scheduler = ChainedScheduler(schedulers)
        self._test([scheduler], targets, epochs)
        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())

    def test_chained_lr2(self):
        epochs = 10
        schedulers = [None] * 1
        targets = [[0.02, 0.03, 0.04] + [0.05] * 9]
        schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3)
        scheduler = ChainedScheduler(schedulers)
        self._test([scheduler], targets, epochs)
        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())

    def test_chained_lr3(self):
        epochs = 10
        schedulers = [None] * 2
        targets = [
            [0.02, 0.03, 0.04, 0.05] + [0.005] * 4 + [0.0005] * 3 + [0.00005] * 3
        ]
        schedulers[0] = LinearLR(self.opt, start_factor=0.4, total_iters=3)
        schedulers[1] = MultiStepLR(self.opt, milestones=[4, 8, 10], gamma=0.1)
        scheduler = ChainedScheduler(schedulers)
        self._test([scheduler], targets, epochs)
        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())

    def test_chained_lr4(self):
        epochs = 9
        schedulers = [None] * 3
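        # Expected LRs are the product of the three chained factors: exponential decay
        # (0.9**epoch), the constant factor (0.2 for the first 4 epochs), and step decay
        # (x0.1 at epochs 3 and 6).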
        targets = [
            [0.05 * 0.2 * 0.9**x for x in range(3)]
            + [0.05 * 0.2 * 0.9**3 * 0.1]
            + [0.05 * 0.9**x * 0.1 for x in range(4, 6)]
            + [0.05 * 0.9**x * 0.01 for x in range(6, 9)]
        ]
        schedulers[0] = ExponentialLR(self.opt, gamma=0.9)
        schedulers[1] = ConstantLR(self.opt, factor=0.2, total_iters=4)
        schedulers[2] = StepLR(self.opt, gamma=0.1, step_size=3)
        scheduler = ChainedScheduler(schedulers)
        self._test([scheduler], targets, epochs)
        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())

    def test_chained_lr5(self):
        def poly_lr(lr: float):
            return [
                (lr * ((1.0 - x / total_iters) ** power)) for x in range(total_iters)
            ] + [0.0] * (epochs - total_iters)

        schedulers = [None] * 2
        epochs = 10
        power = 0.9
        total_iters = 5
        const_factor = 0.1
        single_targets = [x * const_factor for x in poly_lr(lr=0.05)]
        targets = [single_targets, [x * const_factor for x in poly_lr(0.5)]]
        schedulers[0] = PolynomialLR(self.opt, power=power, total_iters=total_iters)
        schedulers[1] = ConstantLR(self.opt, factor=const_factor)
        scheduler = ChainedScheduler(schedulers)
        self._test(scheduler, targets, epochs)
        self.assertEqual(scheduler.get_last_lr(), schedulers[-1].get_last_lr())

    def test_compound_step_and_multistep_lr(self):
        epochs = 10
        schedulers = [None] * 2
        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        targets = [[0.05] * 2 + [0.005] * 1 + [5e-4] * 2 + [5e-5] + [5e-6] * 3 + [5e-8]]
        self._test(schedulers, targets, epochs)

    def test_compound_step_and_exp_lr(self):
        epochs = 10
        schedulers = [None] * 2
        single_targets = [0.05 * (0.9**x) for x in range(3)]
        single_targets += [0.005 * (0.9**x) for x in range(3, 6)]
        single_targets += [0.0005 * (0.9**x) for x in range(6, 9)]
        single_targets += [0.00005 * (0.9**x) for x in range(9, 12)]
        targets = [single_targets, [x * epochs for x in single_targets]]
        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
        self._test(schedulers, targets, epochs)

    def test_compound_exp_and_multistep_lr(self):
        epochs = 10
        schedulers = [None] * 2
        single_targets = [0.05 * (0.9**x) for x in range(2)]
        single_targets += [0.005 * (0.9**x) for x in range(2, 5)]
        single_targets += [0.0005 * (0.9**x) for x in range(5, 9)]
        single_targets += [0.00005 * (0.9**x) for x in range(9, 11)]
        targets = [single_targets, [x * epochs for x in single_targets]]
        schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
        self._test(schedulers, targets, epochs)

    def test_compound_exp_and_linearlr(self):
        epochs = 10
        iters = 4
        start_factor = 0.4
        end_factor = 0.9
        schedulers = [None] * 2
        single_targets = [0.05 * (0.9**x) for x in range(11)]
        for i in range(iters):
            single_targets[i] *= start_factor + i / iters * (end_factor - start_factor)
        for i in range(iters, 11):
            single_targets[i] *= end_factor
        targets = [single_targets, [x * epochs for x in single_targets]]
        schedulers[0] = LinearLR(
            self.opt,
            start_factor=start_factor,
            end_factor=end_factor,
            total_iters=iters,
        )
        schedulers[1] = ExponentialLR(self.opt, gamma=0.9)
        self._test(schedulers, targets, epochs)

    def test_compound_step_and_constantlr(self):
        epochs = 10
        iters = 4
        factor = 0.4
        schedulers = [None] * 2
        single_targets = (
            [0.05 * 0.4] * 3
            + [0.005 * 0.4]
            + [0.005] * 2
            + [0.0005] * 3
            + [0.00005] * 3
        )
        targets = [single_targets, [x * epochs for x in single_targets]]
        schedulers[0] = StepLR(self.opt, gamma=0.1, step_size=3)
        schedulers[1] = ConstantLR(self.opt, factor=0.4, total_iters=4)
        self._test(schedulers, targets, epochs)

    def test_compound_linearlr_and_multistep_lr(self):
        epochs = 10
        iters = 4
        start_factor = 0.4
        schedulers = [None] * 2
        single_targets = [0.05] * 2 + [0.005] * 3 + [0.0005] * 4 + [0.00005] * 2
        for i in range(iters):
            single_targets[i] *= start_factor + i / iters * (1 - start_factor)
        targets = [single_targets, [x * epochs for x in single_targets]]
        schedulers[0] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
        self._test(schedulers, targets, epochs)

    def test_compound_cosanneal_and_step_lr(self):
        epochs = 10
        eta_min = 1e-10
        single_targets = [
            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
            for x in range(epochs)
        ]
        single_targets = [x * 0.1 ** (i // 3) for i, x in enumerate(single_targets)]
        targets = [single_targets, [x * epochs for x in single_targets]]
        schedulers = [None] * 2
        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3)
        self._test(schedulers, targets, epochs)

    def test_compound_cosanneal_and_multistep_lr(self):
        epochs = 10
        eta_min = 1e-10
        single_targets = [
            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
            for x in range(epochs)
        ]
        multipliers = [1] * 2 + [0.1] * 3 + [0.01] * 4 + [0.001]
        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
        targets = [single_targets, [x * epochs for x in single_targets]]
        schedulers = [None] * 2
        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9])
        self._test(schedulers, targets, epochs)

    def test_compound_cosanneal_and_linearlr(self):
        epochs = 10
        iters = 4
        start_factor = 0.4
        eta_min = 1e-10
        schedulers = [None] * 2
        single_targets = [
            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
            for x in range(epochs)
        ]
        for i in range(iters):
            single_targets[i] *= start_factor + i / iters * (1 - start_factor)
        targets = [single_targets, [x * epochs for x in single_targets]]
        schedulers[0] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
        schedulers[1] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
        self._test(schedulers, targets, epochs)

    def test_compound_cosanneal_and_exp_lr(self):
        epochs = 10
        eta_min = 1e-10
        single_targets = [
            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
            for x in range(epochs)
        ]
        multipliers = [0.1**i for i in range(epochs)]
        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
        targets = [single_targets, [x * epochs for x in single_targets]]
        schedulers = [None] * 2
        schedulers[0] = CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min)
        schedulers[1] = ExponentialLR(self.opt, gamma=0.1)
        self._test(schedulers, targets, epochs)

    def test_compound_reduce_lr_on_plateau1(self):
        epochs = 10
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        single_targets = [0.5] * 20
        multipliers = [0.1 ** (i // 3) for i in range(20)]
        single_targets = [x * y for x, y in zip(multipliers, single_targets)]
        targets = [single_targets]
        targets = targets[1:]  # test runs step before checking lr
        metrics = [10 - i * 0.0167 for i in range(20)]
        schedulers = [None, None]
        schedulers[0] = ReduceLROnPlateau(
            self.opt,
            threshold_mode="abs",
            mode="min",
            threshold=0.01,
            patience=5,
            cooldown=5,
        )
        schedulers[1] = StepLR(self.opt, gamma=0.1, step_size=3)
        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)

    def test_compound_reduce_lr_on_plateau2(self):
        epochs = 22
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
        multipliers = [1] * 3 + [0.1] * 5 + [0.01] * 4 + [0.001] * 10
        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
        targets = [single_targets]
        targets = targets[1:]  # test runs step before checking lr
        metrics = [10 - i * 0.0165 for i in range(22)]
        schedulers = [None] * 2
        schedulers[0] = ReduceLROnPlateau(
            self.opt,
            patience=5,
            cooldown=0,
            threshold_mode="abs",
            mode="min",
            threshold=0.1,
        )
        schedulers[1] = MultiStepLR(self.opt, gamma=0.1, milestones=[3, 8, 12])
        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)

    def test_compound_reduce_lr_on_plateau3(self):
        epochs = 22
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        single_targets = [0.5] * (2 + 6) + [0.05] * (5 + 6) + [0.005] * 4
        multipliers = [0.1**i for i in range(epochs)]
        single_targets = [x * y for x, y in zip(multipliers, single_targets)]
        targets = [single_targets]
        targets = targets[1:]  # test runs step before checking lr
        metrics = [-0.8] * 2 + [-0.234] * 20
        schedulers = [None, None]
        schedulers[0] = ReduceLROnPlateau(
            self.opt, mode="max", patience=5, cooldown=5, threshold_mode="abs"
        )
        schedulers[1] = ExponentialLR(self.opt, gamma=0.1)
        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)

    def test_compound_reduce_lr_on_plateau4(self):
        epochs = 10
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.05
        eta_min = 1e-10
        single_targets = [
            eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / epochs)) / 2
            for x in range(epochs)
        ]
        targets = [single_targets]
        targets = targets[1:]  # test runs step before checking lr
        metrics = [1.5 * (1.025**i) for i in range(20)]  # 1.025 > 1.1**0.25
        schedulers = [None, None]
        schedulers[0] = ReduceLROnPlateau(
            self.opt, mode="max", patience=3, threshold_mode="rel", threshold=0.1
        )
        schedulers[1] = CosineAnnealingLR(self.opt, epochs, eta_min)
        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)

    def test_compound_reduce_lr_on_plateau5(self):
        iters = 4
        start_factor = 0.4
        epochs = 22
        for param_group in self.opt.param_groups:
            param_group["lr"] = 0.5
        single_targets = [0.5] * 6 + [0.05] * 7 + [0.005] * 7 + [0.0005] * 2
        multipliers = [1] * 22
        for i in range(iters):
            multipliers[i] *= start_factor + i / iters * (1 - start_factor)
        single_targets = [x * y for x, y in zip(single_targets, multipliers)]
        targets = [single_targets]
        targets = targets[1:]  # test runs step before checking lr
        metrics = [10 - i * 0.0165 for i in range(22)]
        schedulers = [None] * 2
        schedulers[0] = ReduceLROnPlateau(
            self.opt,
            patience=5,
            cooldown=0,
            threshold_mode="abs",
            mode="min",
            threshold=0.1,
        )
        schedulers[1] = LinearLR(self.opt, start_factor=start_factor, total_iters=iters)
        self._test_reduce_lr_on_plateau(schedulers, targets, metrics, epochs)

    def test_cycle_lr_invalid_mode(self):
        with self.assertRaises(ValueError):
            scheduler = CyclicLR(self.opt, base_lr=0, max_lr=0, mode="CATS")

    def test_cycle_lr_triangular_mode_one_lr(self):
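        # With step_size_up=4, the LR ramps linearly from base_lr=1 to max_lr=5 and back,
        # while momentum cycles in the opposite direction.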
        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
        momentum_target = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = CyclicLR(
            self.opt,
            base_lr=1,
            max_lr=5,
            step_size_up=4,
            cycle_momentum=True,
            base_momentum=1,
            max_momentum=5,
            mode="triangular",
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))

    def test_cycle_lr_triangular_mode_one_lr_no_momentum(self):
        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
        lr_targets = [lr_target, lr_target]
        momentum_target = [self.opt.defaults["momentum"]] * len(lr_target)
        momentum_targets = [momentum_target, momentum_target]
        scheduler = CyclicLR(
            self.opt,
            base_lr=1,
            max_lr=5,
            step_size_up=4,
            cycle_momentum=False,
            mode="triangular",
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))

    def test_cycle_lr_triangular2_mode_one_lr(self):
        lr_target = [
            1,
            2,
            3,
            4,
            5,
            4,
            3,
            2,
            1,
            1.5,
            2.0,
            2.5,
            3.0,
            2.5,
            2.0,
            1.5,
            1,
            1.25,
            1.50,
            1.75,
            2.00,
            1.75,
        ]
        momentum_target = [
            5.0,
            4.0,
            3.0,
            2.0,
            1.0,
            2.0,
            3.0,
            4.0,
            5.0,
            4.5,
            4.0,
            3.5,
            3.0,
            3.5,
            4.0,
            4.5,
            5.0,
            4.75,
            4.5,
            4.25,
            4.0,
            4.25,
        ]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = CyclicLR(
            self.opt,
            base_lr=1,
            max_lr=5,
            step_size_up=4,
            cycle_momentum=True,
            base_momentum=1,
            max_momentum=5,
            mode="triangular2",
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))

    def test_cycle_lr_exp_range_mode_one_lr(self):
        base_lr, max_lr = 1, 5
        diff_lr = max_lr - base_lr
        gamma = 0.9
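        # xs are the positions within the triangular cycle; in exp_range mode the
        # amplitude is additionally scaled by gamma**iteration.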
        xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
        lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
        momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = CyclicLR(
            self.opt,
            base_lr=base_lr,
            max_lr=max_lr,
            step_size_up=4,
            cycle_momentum=True,
            base_momentum=base_lr,
            max_momentum=max_lr,
            mode="exp_range",
            gamma=gamma,
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))

    def test_cycle_lr_triangular_mode(self):
        lr_target_1 = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
        lr_target_2 = [x + 1 for x in lr_target_1]
        lr_targets = [lr_target_1, lr_target_2]
        momentum_target_1 = [5, 4, 3, 2, 1, 2, 3, 4, 5, 4, 3]
        momentum_target_2 = [x + 1 for x in momentum_target_1]
        momentum_targets = [momentum_target_1, momentum_target_2]
        scheduler = CyclicLR(
            self.opt,
            base_lr=[1, 2],
            max_lr=[5, 6],
            step_size_up=4,
            cycle_momentum=True,
            base_momentum=[1, 2],
            max_momentum=[5, 6],
            mode="triangular",
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))

    def test_cycle_lr_triangular2_mode(self):
        lr_target_1 = [
            1,
            2,
            3,
            4,
            5,
            4,
            3,
            2,
            1,
            1.5,
            2.0,
            2.5,
            3.0,
            2.5,
            2.0,
            1.5,
            1,
            1.25,
            1.50,
            1.75,
            2.00,
            1.75,
        ]
        lr_target_2 = [x + 2 for x in lr_target_1]
        lr_targets = [lr_target_1, lr_target_2]
        momentum_target_1 = [
            5.0,
            4.0,
            3.0,
            2.0,
            1.0,
            2.0,
            3.0,
            4.0,
            5.0,
            4.5,
            4.0,
            3.5,
            3.0,
            3.5,
            4.0,
            4.5,
            5.0,
            4.75,
            4.5,
            4.25,
            4.0,
            4.25,
        ]
        momentum_target_2 = [x + 2 for x in momentum_target_1]
        momentum_targets = [momentum_target_1, momentum_target_2]
        scheduler = CyclicLR(
            self.opt,
            base_lr=[1, 3],
            max_lr=[5, 7],
            step_size_up=4,
            cycle_momentum=True,
            base_momentum=[1, 3],
            max_momentum=[5, 7],
            mode="triangular2",
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))

    def test_cycle_lr_exp_range_mode(self):
        base_lr_1, max_lr_1 = 1, 5
        base_lr_2, max_lr_2 = 5, 12

        diff_lr_1 = max_lr_1 - base_lr_1
        diff_lr_2 = max_lr_2 - base_lr_2

        gamma = 0.9
        xs = [0, 0.25, 0.5, 0.75, 1, 0.75, 0.50, 0.25, 0, 0.25, 0.5, 0.75, 1]
        lr_target_1 = [base_lr_1 + x * diff_lr_1 * gamma**i for i, x in enumerate(xs)]
        lr_target_2 = [base_lr_2 + x * diff_lr_2 * gamma**i for i, x in enumerate(xs)]
        lr_targets = [lr_target_1, lr_target_2]
        momentum_target_1 = [
            max_lr_1 - x * diff_lr_1 * gamma**i for i, x in enumerate(xs)
        ]
        momentum_target_2 = [
            max_lr_2 - x * diff_lr_2 * gamma**i for i, x in enumerate(xs)
        ]
        momentum_targets = [momentum_target_1, momentum_target_2]
        scheduler = CyclicLR(
            self.opt,
            base_lr=[base_lr_1, base_lr_2],
            max_lr=[max_lr_1, max_lr_2],
            step_size_up=4,
            cycle_momentum=True,
            base_momentum=[base_lr_1, base_lr_2],
            max_momentum=[max_lr_1, max_lr_2],
            mode="exp_range",
            gamma=gamma,
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target_1))

    def test_cycle_lr_triangular_mode_step_size_up_down(self):
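        # Asymmetric cycle: the LR ramps up from 1 to 5 over 4 steps and back down over
        # 6 steps; momentum moves inversely.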
        lr_target = [
            1.0,
            2.0,
            3.0,
            4.0,
            5.0,
            13.0 / 3,
            11.0 / 3,
            9.0 / 3,
            7.0 / 3,
            5.0 / 3,
            1.0,
        ]
        lr_targets = [lr_target, lr_target]
        momentum_target = [
            5.0,
            4.0,
            3.0,
            2.0,
            1.0,
            5.0 / 3,
            7.0 / 3,
            3.0,
            11.0 / 3,
            13.0 / 3,
            5.0,
        ]
        momentum_targets = [momentum_target, momentum_target]

        scheduler = CyclicLR(
            self.opt,
            base_lr=1,
            max_lr=5,
            step_size_up=4,
            step_size_down=6,
            cycle_momentum=True,
            base_momentum=1,
            max_momentum=5,
            mode="triangular",
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))

    def test_cycle_lr_triangular2_mode_step_size_up_down(self):
        lr_base_target = [
            1.0,
            3.0,
            5.0,
            13.0 / 3,
            11.0 / 3,
            9.0 / 3,
            7.0 / 3,
            5.0 / 3,
            1.0,
            2.0,
            3.0,
            8.0 / 3,
            7.0 / 3,
            6.0 / 3,
            5.0 / 3,
            4.0 / 3,
            1.0,
            3.0 / 2,
            2.0,
            11.0 / 6,
            10.0 / 6,
            9.0 / 6,
            8.0 / 6,
            7.0 / 6,
        ]
        momentum_base_target = [
            5.0,
            3.0,
            1.0,
            5.0 / 3,
            7.0 / 3,
            3.0,
            11.0 / 3,
            13.0 / 3,
            5.0,
            4.0,
            3.0,
            10.0 / 3,
            11.0 / 3,
            4.0,
            13.0 / 3,
            14.0 / 3,
            5.0,
            4.5,
            4.0,
            25.0 / 6,
            13.0 / 3,
            4.5,
            14.0 / 3,
            29.0 / 6,
        ]
        deltas = [2 * i for i in range(0, 2)]
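        # Two parameter groups: the second group's base/max lr and momentum are offset
        # by 2 (deltas = [0, 2]).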
        base_lrs = [1 + delta for delta in deltas]
        max_lrs = [5 + delta for delta in deltas]
        lr_targets = [[x + delta for x in lr_base_target] for delta in deltas]
        momentum_targets = [
            [x + delta for x in momentum_base_target] for delta in deltas
        ]
        scheduler = CyclicLR(
            self.opt,
            base_lr=base_lrs,
            max_lr=max_lrs,
            step_size_up=2,
            step_size_down=6,
            cycle_momentum=True,
            base_momentum=base_lrs,
            max_momentum=max_lrs,
            mode="triangular2",
        )
        self._test_cycle_lr(
            scheduler, lr_targets, momentum_targets, len(lr_base_target)
        )

    def test_cycle_lr_exp_range_mode_step_size_up_down(self):
        base_lr, max_lr = 1, 5
        diff_lr = max_lr - base_lr
        gamma = 0.9
        xs = [
            0.0,
            0.5,
            1.0,
            5.0 / 6,
            4.0 / 6,
            3.0 / 6,
            2.0 / 6,
            1.0 / 6,
            0.0,
            0.5,
            1.0,
            5.0 / 6,
            4.0 / 6,
        ]
        lr_target = [base_lr + x * diff_lr * gamma**i for i, x in enumerate(xs)]
        lr_targets = [lr_target, lr_target]
        momentum_target = [max_lr - x * diff_lr * gamma**i for i, x in enumerate(xs)]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = CyclicLR(
            self.opt,
            base_lr=base_lr,
            max_lr=max_lr,
            step_size_up=2,
            step_size_down=6,
            cycle_momentum=True,
            base_momentum=base_lr,
            max_momentum=max_lr,
            mode="exp_range",
            gamma=gamma,
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))

    def test_cycle_lr_with_momentumless_optimizer(self):
        # Note [Temporarily set optimizer to Adam]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # The TestLRScheduler object carries around an SGD optimizer to avoid having to
        # instantiate one for every test. This gets in the way for our very specific case
        # in which we need to use Adam (or really any optimizer that doesn't use momentum)
        # in order to test that the momentum bug in CyclicLR is fixed (the bug is described
        # in more detail in https://github.com/pytorch/pytorch/issues/19003 ).
        old_opt = self.opt
        self.opt = Adam(
            [
                {"params": self.net.conv1.parameters()},
                {"params": self.net.conv2.parameters(), "lr": 0.5},
            ],
            lr=0.05,
        )

        lr_target = [1, 2, 3, 4, 5, 4, 3, 2, 1, 2, 3]
        lr_targets = [lr_target, lr_target]
        momentum_target = [None] * len(lr_target)
        momentum_targets = [momentum_target, momentum_target]
        scheduler = CyclicLR(
            self.opt,
            base_lr=1,
            max_lr=5,
            step_size_up=4,
            cycle_momentum=False,
            mode="triangular",
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, len(lr_target))

        self.opt = old_opt  # set optimizer back to SGD

    def test_cycle_lr_cycle_momentum_fail_with_momentumless_optimizer(self):
        with self.assertRaises(ValueError):
            rprop_opt = Rprop(self.net.parameters())
            scheduler = CyclicLR(rprop_opt, base_lr=1, max_lr=5, cycle_momentum=True)

    def test_cycle_lr_cycle_momentum_with_beta1_optimizer(self):
        adam_opt = Adam(self.net.parameters())
        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=True)

    def test_cycle_lr_removed_after_out_of_scope(self):
        import gc
        import weakref

        gc.disable()

        def test():
            adam_opt = Adam(self.net.parameters())
            scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
            return weakref.ref(scheduler)

        ref = test()
        assert ref() is None
        gc.enable()

    def test_cycle_lr_state_dict_picklable(self):
        adam_opt = Adam(self.net.parameters())
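
        # In each case, the scale function reference should be excluded from the saved
        # state dict so that the state dict stays picklable.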

        # Case 1: Built-in mode
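        # The built-in scale function lives in `_scale_fn_ref` and is dropped from the
        # state dict, so the state stays picklable and `_scale_fn_custom` remains None.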
        scheduler = CyclicLR(adam_opt, base_lr=1, max_lr=5, cycle_momentum=False)
        self.assertIsInstance(scheduler._scale_fn_ref, types.FunctionType)
        state = scheduler.state_dict()
        self.assertNotIn("_scale_fn_ref", state)
        self.assertIs(state["_scale_fn_custom"], None)
        pickle.dumps(state)

        # Case 2: Custom `scale_fn`, a function object
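        # As with LambdaLR's lr_lambdas, a plain function is not serialized into the
        # state dict, so `_scale_fn_custom` is still saved as None.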
        def scale_fn(_):
            return 0.5

        scheduler = CyclicLR(
            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
        )
        state = scheduler.state_dict()
        self.assertNotIn("_scale_fn_ref", state)
        self.assertIs(state["_scale_fn_custom"], None)
        pickle.dumps(state)

        # Case 3: Custom `scale_fn`, a callable class
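        # For a callable object, only its __dict__ is captured in the state dict, which
        # keeps the state picklable without pickling the object itself.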
        class ScaleFn:
            def __init__(self) -> None:
                self.x = 0.5

            def __call__(self, _):
                return self.x

        scale_fn = ScaleFn()

        scheduler = CyclicLR(
            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
        )
        state = scheduler.state_dict()
        self.assertNotIn("_scale_fn_ref", state)
        self.assertEqual(state["_scale_fn_custom"], scale_fn.__dict__)
        pickle.dumps(state)

    def test_cycle_lr_scale_fn_restored_from_state_dict(self):
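        # A built-in mode should be restored entirely from the state dict, whereas a
        # custom scale_fn is not serialized and must already be set on the scheduler
        # that loads the state.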
        adam_opt = Adam(self.net.parameters())

        # Case 1: Built-in mode
        scheduler = CyclicLR(
            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, mode="triangular2"
        )
        restored_scheduler = CyclicLR(
            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False
        )
        restored_scheduler.load_state_dict(scheduler.state_dict())
        self.assertTrue(restored_scheduler.mode == scheduler.mode == "triangular2")
        self.assertIsNotNone(restored_scheduler._scale_fn_ref)
        self.assertIsNotNone(scheduler._scale_fn_ref)
        self.assertIs(restored_scheduler._scale_fn_custom, None)
        self.assertIs(scheduler._scale_fn_custom, None)

        # Case 2: Custom `scale_fn`
        def scale_fn(_):
            return 0.5

        scheduler = CyclicLR(
            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
        )
        restored_scheduler = CyclicLR(
            adam_opt, base_lr=1, max_lr=5, cycle_momentum=False, scale_fn=scale_fn
        )
        restored_scheduler.load_state_dict(scheduler.state_dict())
        self.assertIs(scheduler._scale_fn_custom, scale_fn)
        self.assertIs(restored_scheduler._scale_fn_custom, scale_fn)

    def test_onecycle_lr_invalid_anneal_strategy(self):
        with self.assertRaises(ValueError):
            scheduler = OneCycleLR(
                self.opt, max_lr=1e-3, total_steps=10, anneal_strategy="CATS"
            )

    def test_onecycle_lr_invalid_pct_start(self):
        with self.assertRaises(ValueError):
            scheduler = OneCycleLR(self.opt, max_lr=1e-3, total_steps=10, pct_start=1.1)

    def test_onecycle_lr_cannot_calculate_total_steps(self):
        with self.assertRaises(ValueError):
            scheduler = OneCycleLR(self.opt, max_lr=1e-3)

    def test_onecycle_lr_linear_annealing(self):
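        # With the default div_factor of 25 the initial lr is max_lr / 25 = 1, and
        # final_div_factor=2 anneals it back down to initial_lr / 2 = 0.5; the default
        # pct_start of 0.3 over 10 total steps puts the peak of 25 at step 2.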
        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = OneCycleLR(
            self.opt,
            max_lr=25,
            final_div_factor=2,
            base_momentum=1,
            max_momentum=22,
            total_steps=10,
            anneal_strategy="linear",
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)

    def test_onecycle_lr_linear_annealing_three_phases(self):
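        # With three_phase=True the lr ramps 1 -> 25 over the first pct_start fraction
        # of the steps, comes back down symmetrically, and then anneals from the
        # initial lr down to initial_lr / final_div_factor = 0.25 while momentum is
        # held at max_momentum, as encoded in the targets below.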
        lr_target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25]
        momentum_target = [22, 15, 8, 1, 8, 15, 22, 22, 22, 22]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = OneCycleLR(
            self.opt,
            max_lr=25,
            div_factor=25,
            base_momentum=1,
            max_momentum=22,
            total_steps=10,
            anneal_strategy="linear",
            pct_start=0.4,
            final_div_factor=4,
            three_phase=True,
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)

    def test_onecycle_lr_cosine_annealing(self):
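        # Local copy of the cosine interpolation OneCycleLR uses between phase
        # boundaries, used here to compute the expected lrs and momenta.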
        def annealing_cos(start, end, pct):
            cos_out = math.cos(math.pi * pct) + 1
            return end + (start - end) / 2.0 * cos_out

        lr_target = [
            1,
            13,
            25,
            annealing_cos(25, 0.5, 1 / 7.0),
            annealing_cos(25, 0.5, 2 / 7.0),
            annealing_cos(25, 0.5, 3 / 7.0),
            annealing_cos(25, 0.5, 4 / 7.0),
            annealing_cos(25, 0.5, 5 / 7.0),
            annealing_cos(25, 0.5, 6 / 7.0),
            0.5,
        ]
        momentum_target = [
            22,
            11.5,
            1,
            annealing_cos(1, 22, 1 / 7.0),
            annealing_cos(1, 22, 2 / 7.0),
            annealing_cos(1, 22, 3 / 7.0),
            annealing_cos(1, 22, 4 / 7.0),
            annealing_cos(1, 22, 5 / 7.0),
            annealing_cos(1, 22, 6 / 7.0),
            22,
        ]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = OneCycleLR(
            self.opt,
            max_lr=25,
            final_div_factor=2,
            base_momentum=1,
            max_momentum=22,
            total_steps=10,
        )
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)

    def test_onecycle_lr_legacy_state_dict(self):
        scheduler = OneCycleLR(
            self.opt,
            max_lr=25,
            final_div_factor=2,
            base_momentum=1,
            max_momentum=22,
            total_steps=10,
            anneal_strategy="cos",
        )
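        # Simulate a checkpoint from an older release: drop `_anneal_func_type` and
        # store the annealing function itself under "anneal_func", then make sure such
        # a legacy state dict can still be loaded and stepped.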
        delattr(scheduler, "_anneal_func_type")
        state_dict = scheduler.state_dict()
        self.assertNotIn("anneal_func_type", state_dict)
        state_dict["anneal_func"] = OneCycleLR._annealing_cos
        scheduler.load_state_dict(state_dict)

        def annealing_cos(start, end, pct):
            cos_out = math.cos(math.pi * pct) + 1
            return end + (start - end) / 2.0 * cos_out

        lr_target = [
            1,
            13,
            25,
            annealing_cos(25, 0.5, 1 / 7.0),
            annealing_cos(25, 0.5, 2 / 7.0),
            annealing_cos(25, 0.5, 3 / 7.0),
            annealing_cos(25, 0.5, 4 / 7.0),
            annealing_cos(25, 0.5, 5 / 7.0),
            annealing_cos(25, 0.5, 6 / 7.0),
            0.5,
        ]
        momentum_target = [
            22,
            11.5,
            1,
            annealing_cos(1, 22, 1 / 7.0),
            annealing_cos(1, 22, 2 / 7.0),
            annealing_cos(1, 22, 3 / 7.0),
            annealing_cos(1, 22, 4 / 7.0),
            annealing_cos(1, 22, 5 / 7.0),
            annealing_cos(1, 22, 6 / 7.0),
            22,
        ]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        self._test_cycle_lr(scheduler, lr_targets, momentum_targets, 10)

    def test_cycle_lr_with_adam(self):
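        # Same linear-annealing targets as test_onecycle_lr_linear_annealing, but with
        # Adam so that beta1 (checked via use_beta1=True) is cycled instead of momentum.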
        old_opt = self.opt
        self.opt = Adam(
            [
                {"params": self.net.conv1.parameters()},
                {"params": self.net.conv2.parameters(), "lr": 0.5},
            ],
            lr=0.05,
        )

        lr_target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
        momentum_target = [22, 11.5, 1, 4, 7, 10, 13, 16, 19, 22]
        lr_targets = [lr_target, lr_target]
        momentum_targets = [momentum_target, momentum_target]
        scheduler = OneCycleLR(
            self.opt,
            max_lr=25,
            final_div_factor=2,
            base_momentum=1,
            max_momentum=22,
            total_steps=10,
            anneal_strategy="linear",
        )
        self._test_cycle_lr(
            scheduler, lr_targets, momentum_targets, 10, use_beta1=True
        )
        self.opt = old_opt  # set optimizer back to SGD

    def test_lambda_lr(self):
        epochs = 10
        self.opt.param_groups[0]["lr"] = 0.05
        self.opt.param_groups[1]["lr"] = 0.4
        targets = [
            [0.05 * (0.9**x) for x in range(epochs)],
            [0.4 * (0.8**x) for x in range(epochs)],
        ]
        scheduler = LambdaLR(
            self.opt, lr_lambda=[lambda x1: 0.9**x1, lambda x2: 0.8**x2]
        )
        self._test(scheduler, targets, epochs)

    def test_multiplicative_lr(self):
        epochs = 10
        self.opt.param_groups[0]["lr"] = 0.05
        self.opt.param_groups[1]["lr"] = 0.4
        targets = [
            [0.05 * (0.9**x) for x in range(epochs)],
            [0.4 * (0.8**x) for x in range(epochs)],
        ]
        scheduler = MultiplicativeLR(
            self.opt, lr_lambda=[lambda x1: 0.9, lambda x2: 0.8]
        )
        self._test(scheduler, targets, epochs)

    @parametrize("T_mult", [1, 2, 4])
    def test_CosineAnnealingWarmRestarts_lr1(self, T_mult):
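        # Closed-form SGDR schedule:
        # eta = eta_min + (base_lr - eta_min) * (1 + cos(pi * T_cur / T_i)) / 2,
        # restarting whenever T_cur reaches T_i and multiplying T_i by T_mult.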
        iters = 100
        eta_min = 1e-10
        T_i = 10
        T_cur = 0
        targets = [[0.05], [0.5]]
        scheduler = CosineAnnealingWarmRestarts(
            self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min
        )
        for _ in range(1, iters, 1):
            T_cur += 1
            if T_cur >= T_i:
                T_cur = T_cur - T_i
                T_i = int(T_mult) * T_i
            targets[0] += [
                eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
            ]
            targets[1] += [
                eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
            ]
        self._test(scheduler, targets, iters)

    def test_CosineAnnealingWarmRestarts_lr2(self):
        iters = 30
        eta_min = 1e-10
        T_mults = [1, 2, 4]
        for T_mult in T_mults:
            T_i = 10
            T_cur = 0
            targets = [[0.05], [0.5]]
            scheduler = CosineAnnealingWarmRestarts(
                self.opt, T_0=T_i, T_mult=T_mult, eta_min=eta_min
            )
            for _ in torch.arange(0.1, iters, 0.1):
                T_cur = round(T_cur + 0.1, 1)
                if T_cur >= T_i:
                    T_cur = T_cur - T_i
                    T_i = int(T_mult) * T_i
                targets[0] += [
                    eta_min
                    + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
                ]
                targets[1] += [
                    eta_min
                    + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
                ]
            self._test_CosineAnnealingWarmRestarts(scheduler, targets, iters)

    def test_CosineAnnealingWarmRestarts_lr3(self):
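        # For each T_mult, `epochs` is a hand-picked, non-monotonic sequence passed to
        # scheduler.step(epoch); T_curs/T_is give the expected position within the
        # restart cycle for every epoch after the first.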
        epochs_for_T_mults = [
            [0, 1, 2, 3, 4, 5, 12, 27, 3, 4, 5, 6, 13],
            [0, 1, 2, 3, 4, 5, 25, 32, 33, 34, 80, 81, 3],
            [0, 0.1, 0.2, 0.3, 1.3, 2.3, 17.5, 18.5, 19.5, 29.5, 30.5, 31.5, 50],
        ]
        T_curs_for_T_mults = [
            [1, 2, 3, 4, 5, 2, 7, 3, 4, 5, 6, 3],
            [1, 2, 3, 4, 5, 15, 2, 3, 4, 10, 11, 3],
            [0.1, 0.2, 0.3, 1.3, 2.3, 7.5, 8.5, 9.5, 19.5, 20.5, 21.5, 10],
        ]
        T_is_for_T_mults = [
            [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
            [10, 10, 10, 10, 10, 20, 40, 40, 40, 80, 80, 10],
            [10, 10, 10, 10, 10, 30, 30, 30, 30, 30, 30, 90],
        ]
        eta_min = 1e-10
        T_mults = [1, 2, 3]
        for epochs, T_mult, T_curs, T_is in zip(
            epochs_for_T_mults, T_mults, T_curs_for_T_mults, T_is_for_T_mults
        ):
            targets = [[0.05], [0.5]]
            scheduler = CosineAnnealingWarmRestarts(
                self.opt, T_0=10, T_mult=T_mult, eta_min=eta_min
            )
            for T_cur, T_i in zip(T_curs, T_is):
                targets[0] += [
                    eta_min
                    + (0.05 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
                ]
                targets[1] += [
                    eta_min
                    + (0.5 - eta_min) * (1 + math.cos(math.pi * T_cur / T_i)) / 2
                ]
            self._test_interleaved_CosineAnnealingWarmRestarts(
                scheduler, targets, epochs
            )

    def test_swalr_no_anneal(self):
        epochs, swa_start, swa_lr = 10, 5, 0.01
        initial_lrs = [group["lr"] for group in self.opt.param_groups]
        targets = [
            [lr] * (swa_start + 1) + [swa_lr] * (epochs - swa_start - 1)
            for lr in initial_lrs
        ]
        swa_scheduler = SWALR(self.opt, anneal_epochs=1, swa_lr=swa_lr)
        self._test_swalr(swa_scheduler, None, targets, swa_start, epochs)

    def test_swalr_cosine_anneal_after_multiplicative(self):
        # same swa_lr for different param_groups
        epochs, swa_start, swa_lr, anneal_epochs = 15, 5, 0.01, 5
        mult_factor = 0.9
        scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
        swa_scheduler = SWALR(self.opt, anneal_epochs=anneal_epochs, swa_lr=swa_lr)

        def anneal_coef(t):
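            # Expected cosine annealing coefficient after t + 1 SWA steps; the target
            # lr is coef * last_lr + (1 - coef) * swa_lr, hitting swa_lr exactly once
            # anneal_epochs steps have been taken.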
            if t + 1 >= anneal_epochs:
                return 0.0
            return (1 + math.cos(math.pi * (t + 1) / anneal_epochs)) / 2

        initial_lrs = [group["lr"] for group in self.opt.param_groups]
        targets_before_swa = [
            [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs
        ]
        swa_epochs = epochs - swa_start - 1
        targets = [
            lrs
            + [
                lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t))
                for t in range(swa_epochs)
            ]
            for lrs in targets_before_swa
        ]

        self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)

    def test_swalr_linear_anneal_after_multiplicative(self):
        # separate swa_lr for different param_groups
        epochs, swa_start, swa_lrs, anneal_epochs = 15, 5, [0.01, 0.02], 4
        mult_factor = 0.9
        scheduler = MultiplicativeLR(self.opt, lr_lambda=lambda epoch: mult_factor)
        swa_scheduler = SWALR(
            self.opt,
            anneal_epochs=anneal_epochs,
            anneal_strategy="linear",
            swa_lr=swa_lrs,
        )

        def anneal_coef(t):
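            # Expected linear annealing coefficient; it reaches 0 once anneal_epochs
            # steps have been taken.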
            if t + 1 >= anneal_epochs:
                return 0.0
            return 1 - (t + 1) / anneal_epochs

        initial_lrs = [group["lr"] for group in self.opt.param_groups]
        targets_before_swa = [
            [lr * mult_factor**i for i in range(swa_start + 1)] for lr in initial_lrs
        ]
        swa_epochs = epochs - swa_start - 1
        targets = [
            lrs
            + [
                lrs[-1] * anneal_coef(t) + swa_lr * (1 - anneal_coef(t))
                for t in range(swa_epochs)
            ]
            for lrs, swa_lr in zip(targets_before_swa, swa_lrs)
        ]

        self._test_swalr(swa_scheduler, scheduler, targets, swa_start, epochs)

    def _test_swalr(self, swa_scheduler, scheduler, targets, swa_start, epochs):
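        # Step the regular scheduler until swa_start is reached, then hand over to the
        # SWA scheduler; lrs are compared against `targets` before each step.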
        for epoch in range(epochs):
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[epoch],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[epoch], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )
            if epoch >= swa_start:
                self.opt.step()
                swa_scheduler.step()
            elif scheduler is not None:
                self.opt.step()
                scheduler.step()

    def test_swalr_hypers(self):
        # Test that SWALR raises errors for incorrect hyper-parameters
        with self.assertRaisesRegex(ValueError, "anneal_strategy must"):
            swa_scheduler = SWALR(self.opt, anneal_strategy="exponential", swa_lr=1.0)

        with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
            swa_scheduler = SWALR(self.opt, anneal_epochs=-1, swa_lr=1.0)
        with self.assertRaisesRegex(ValueError, "anneal_epochs must"):
            swa_scheduler = SWALR(self.opt, anneal_epochs=1.7, swa_lr=1.0)
        with self.assertRaisesRegex(ValueError, "swa_lr must"):
            swa_scheduler = SWALR(self.opt, swa_lr=[1.0, 0.1, 0.01])

    def test_step_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: StepLR(self.opt, gamma=0.1, step_size=3),
            lambda: StepLR(self.opt, gamma=0.01 / 2, step_size=1),
        )

    def test_multi_step_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: MultiStepLR(self.opt, gamma=0.1, milestones=[2, 5, 9]),
            lambda: MultiStepLR(self.opt, gamma=0.01, milestones=[1, 4, 6]),
        )

    def test_exp_step_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: ExponentialLR(self.opt, gamma=0.1),
            lambda: ExponentialLR(self.opt, gamma=0.01),
        )

    def test_cosine_lr_state_dict(self):
        epochs = 10
        eta_min = 1e-10
        self._check_scheduler_state_dict(
            lambda: CosineAnnealingLR(self.opt, T_max=epochs, eta_min=eta_min),
            lambda: CosineAnnealingLR(self.opt, T_max=epochs // 2, eta_min=eta_min / 2),
            epochs=epochs,
        )

    def test_reduce_lr_on_plateau_state_dict(self):
        scheduler = ReduceLROnPlateau(self.opt, mode="min", factor=0.1, patience=2)
        for score in [1.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0, 3.0, 2.0, 1.0]:
            scheduler.step(score)
        scheduler_copy = ReduceLROnPlateau(
            self.opt, mode="max", factor=0.5, patience=10
        )
        scheduler_copy.load_state_dict(scheduler.state_dict())
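        # `optimizer` is never serialized, and `is_better` is a callable rebuilt on
        # load, so both are excluded from the field-by-field comparison.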
        for key in scheduler.__dict__.keys():
            if key not in {"optimizer", "is_better"}:
                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])

    def test_lambda_lr_state_dict_fn(self):
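        # A plain lambda is not serializable, so it is stored as None in the state dict
        # and the loading scheduler keeps its own lr_lambda.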
        scheduler = LambdaLR(self.opt, lr_lambda=lambda x: x)
        state = scheduler.state_dict()
        self.assertIsNone(state["lr_lambdas"][0])

        scheduler_copy = LambdaLR(self.opt, lr_lambda=lambda x: x)
        scheduler_copy.load_state_dict(state)
        for key in scheduler.__dict__.keys():
            if key not in {"optimizer", "lr_lambdas"}:
                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])

    def test_lambda_lr_state_dict_obj(self):
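        # A callable object, by contrast, has its __dict__ saved and restored, so even
        # the lr_lambdas entry compares equal after loading.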
        scheduler = LambdaLR(self.opt, lr_lambda=self.LambdaLRTestObject(10))
        state = scheduler.state_dict()
        self.assertIsNotNone(state["lr_lambdas"][0])

        scheduler_copy = LambdaLR(self.opt, lr_lambda=self.LambdaLRTestObject(-1))
        scheduler_copy.load_state_dict(state)
        for key in scheduler.__dict__.keys():
            if key not in {"optimizer"}:
                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])

    def test_CosineAnnealingWarmRestarts_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: CosineAnnealingWarmRestarts(self.opt, T_0=10, T_mult=2),
            lambda: CosineAnnealingWarmRestarts(self.opt, T_0=100),
        )

    def test_swa_lr_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: SWALR(self.opt, anneal_epochs=3, swa_lr=0.5),
            lambda: SWALR(
                self.opt, anneal_epochs=10, anneal_strategy="linear", swa_lr=5.0
            ),
        )

    def _check_scheduler_state_dict(self, constr, constr2, epochs=10):
        scheduler = constr()
        for _ in range(epochs):
            scheduler.optimizer.step()
            scheduler.step()
        scheduler_copy = constr2()
        scheduler_copy.load_state_dict(scheduler.state_dict())
        for key in scheduler.__dict__.keys():
            if key != "optimizer":
                self.assertEqual(scheduler.__dict__[key], scheduler_copy.__dict__[key])
        self.assertEqual(scheduler.get_last_lr(), scheduler_copy.get_last_lr())

    def _test_get_last_lr(self, schedulers, targets, epochs=10):
        if isinstance(schedulers, LRScheduler):
            schedulers = [schedulers]
        optimizers = {scheduler.optimizer for scheduler in schedulers}
        for epoch in range(epochs):
            result = [scheduler.get_last_lr() for scheduler in schedulers]
            for optimizer in optimizers:
                optimizer.step()
            for scheduler in schedulers:
                scheduler.step()
            target = [[t[epoch] for t in targets]] * len(schedulers)
            for t, r in zip(target, result):
                self.assertEqual(
                    t,
                    r,
                    msg=f"LR is wrong in epoch {epoch}: expected {t}, got {r}",
                    atol=1e-5,
                    rtol=0,
                )

    def _test_with_epoch(self, schedulers, targets, epochs=10):
        if isinstance(schedulers, LRScheduler):
            schedulers = [schedulers]
        optimizers = {scheduler.optimizer for scheduler in schedulers}
        for epoch in range(epochs):
            for optimizer in optimizers:
                optimizer.step()
            with warnings.catch_warnings(record=True) as w:
                # step before assert: skip initial lr
                for scheduler in schedulers:
                    scheduler.step(epoch)
                self._check_warning_is_epoch_deprecation_warning(
                    w, num_warnings=len(schedulers)
                )
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[epoch],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[epoch], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

    def _test(self, schedulers, targets, epochs=10):
        if isinstance(schedulers, LRScheduler):
            schedulers = [schedulers]
        for epoch in range(epochs):
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[epoch],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[epoch], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )
            for scheduler in schedulers:
                scheduler.step()

    def _test_CosineAnnealingWarmRestarts(self, scheduler, targets, epochs=10):
        for index, epoch in enumerate(torch.arange(0, epochs, 0.1)):
            epoch = round(epoch.item(), 1)
            scheduler.step(epoch)
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[index],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[index], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

    def _test_interleaved_CosineAnnealingWarmRestarts(
        self, scheduler, targets, epochs
    ):
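        # `epochs` is an arbitrary (possibly fractional, non-monotonic) sequence of
        # epochs passed explicitly to scheduler.step().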
        for index, epoch in enumerate(epochs):
            scheduler.step(epoch)
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[index],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[index], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

    def _test_against_closed_form(self, scheduler, closed_form_scheduler, epochs=10):
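        # First record the lrs produced by the closed-form scheduler when stepped with
        # explicit epochs, then reset and check that the iteratively stepped scheduler
        # reproduces them.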
        self.setUp()
        targets = []
        for epoch in range(epochs):
            closed_form_scheduler.optimizer.step()
            with warnings.catch_warnings(record=True) as w:
                closed_form_scheduler.step(epoch)
                self._check_warning_is_epoch_deprecation_warning(w)
            targets.append([group["lr"] for group in self.opt.param_groups])
        self.setUp()
        for epoch in range(epochs):
            self.opt.step()
            scheduler.step()
            for i, param_group in enumerate(self.opt.param_groups):
                self.assertEqual(
                    targets[epoch][i],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, targets[epoch][i], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

    def _test_reduce_lr_on_plateau(
        self, schedulers, targets, metrics, epochs=10, verbose=False
    ):
        if isinstance(schedulers, (LRScheduler, ReduceLROnPlateau)):
            schedulers = [schedulers]
        for epoch in range(epochs):
            self.opt.step()
            for scheduler in schedulers:
                if isinstance(scheduler, ReduceLROnPlateau):
                    scheduler.step(metrics[epoch])
                else:
                    scheduler.step()
            if verbose:
                print("epoch{}:\tlr={}".format(epoch, self.opt.param_groups[0]["lr"]))
            for param_group, target in zip(self.opt.param_groups, targets):
                self.assertEqual(
                    target[epoch],
                    param_group["lr"],
                    msg="LR is wrong in epoch {}: expected {}, got {}".format(
                        epoch, target[epoch], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

    def _test_cycle_lr(
        self,
        scheduler,
        lr_targets,
        momentum_targets,
        batch_iterations,
        verbose=False,
        use_beta1=False,
    ):
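        # For every batch, compare each param group's lr (and its momentum, or beta1
        # when use_beta1=True) against the targets before stepping the optimizer and
        # the scheduler.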
        for batch_num in range(batch_iterations):
            if verbose:
                if "momentum" in self.opt.param_groups[0].keys():
                    print(
                        "batch{}:\tlr={},momentum={}".format(
                            batch_num,
                            self.opt.param_groups[0]["lr"],
                            self.opt.param_groups[0]["momentum"],
                        )
                    )
                elif use_beta1 and "betas" in self.opt.param_groups[0].keys():
                    print(
                        "batch{}:\tlr={},beta1={}".format(
                            batch_num,
                            self.opt.param_groups[0]["lr"],
                            self.opt.param_groups[0]["betas"][0],
                        )
                    )
                else:
                    print(
                        "batch{}:\tlr={}".format(
                            batch_num, self.opt.param_groups[0]["lr"]
                        )
                    )

            for param_group, lr_target, momentum_target in zip(
                self.opt.param_groups, lr_targets, momentum_targets
            ):
                self.assertEqual(
                    lr_target[batch_num],
                    param_group["lr"],
                    msg="LR is wrong in batch_num {}: expected {}, got {}".format(
                        batch_num, lr_target[batch_num], param_group["lr"]
                    ),
                    atol=1e-5,
                    rtol=0,
                )

                if use_beta1 and "betas" in param_group.keys():
                    self.assertEqual(
                        momentum_target[batch_num],
                        param_group["betas"][0],
                        msg="Beta1 is wrong in batch_num {}: expected {}, got {}".format(
                            batch_num,
                            momentum_target[batch_num],
                            param_group["betas"][0],
                        ),
                        atol=1e-5,
                        rtol=0,
                    )
                elif "momentum" in param_group.keys():
                    self.assertEqual(
                        momentum_target[batch_num],
                        param_group["momentum"],
                        msg="Momentum is wrong in batch_num {}: expected {}, got {}".format(
                            batch_num,
                            momentum_target[batch_num],
                            param_group["momentum"],
                        ),
                        atol=1e-5,
                        rtol=0,
                    )
            self.opt.step()
            scheduler.step()

    def test_cosine_then_cyclic(self):
        # https://github.com/pytorch/pytorch/issues/21965
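        # Regression test: after switching from CosineAnnealingLR to CyclicLR partway
        # through training, the final lr must not exceed CyclicLR's max_lr.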

        max_lr = 0.3
        base_lr = 0.1
        optim_lr = 0.5

        model = torch.nn.Linear(2, 1)
        optimizer = SGD(model.parameters(), lr=optim_lr)
        lr_scheduler_1 = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=20, eta_min=0.1
        )
        lr_scheduler_2 = torch.optim.lr_scheduler.CyclicLR(
            optimizer, base_lr=base_lr, max_lr=max_lr, step_size_up=1, step_size_down=3
        )

        for i in range(40):
            optimizer.step()
            if i <= lr_scheduler_1.T_max:
                lr_scheduler_1.step()
            else:
                lr_scheduler_2.step()
            last_lr = optimizer.param_groups[0]["lr"]

        self.assertLessEqual(last_lr, max_lr)

    @parametrize(
        "LRClass",
        [
            partial(LambdaLR, lr_lambda=lambda e: e // 10),
            partial(MultiplicativeLR, lr_lambda=lambda e: 0.95),
            partial(StepLR, step_size=30),
            partial(MultiStepLR, milestones=[30, 80]),
            ConstantLR,
            LinearLR,
            partial(ExponentialLR, gamma=0.9),
            lambda opt, **kwargs: SequentialLR(
                opt,
                schedulers=[ConstantLR(opt), ConstantLR(opt)],
                milestones=[2],
                **kwargs,
            ),
            PolynomialLR,
            partial(CosineAnnealingLR, T_max=10),
            ReduceLROnPlateau,
            partial(CyclicLR, base_lr=0.01, max_lr=0.1),
            partial(CosineAnnealingWarmRestarts, T_0=20),
            partial(OneCycleLR, max_lr=0.01, total_steps=10),
        ],
    )
    def test_lr_scheduler_verbose_deprecation_warning(self, LRClass):
        """Check that a deprecating warning with verbose parameter."""
        with self.assertWarnsOnceRegex(
            UserWarning, "The verbose parameter is deprecated"
        ):
            LRClass(self.opt, verbose=True)

        with self.assertWarnsOnceRegex(
            UserWarning, "The verbose parameter is deprecated"
        ):
            LRClass(self.opt, verbose=False)

        # No warning is raised when verbose is the default value.
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            LRClass(self.opt)

    @parametrize(
        "LRClass",
        [
            partial(LambdaLR, lr_lambda=lambda e: e // 10),
            partial(MultiplicativeLR, lr_lambda=lambda e: 0.95),
            partial(StepLR, step_size=30),
            partial(MultiStepLR, milestones=[30, 80]),
            ConstantLR,
            LinearLR,
            partial(ExponentialLR, gamma=0.9),
            PolynomialLR,
            partial(CosineAnnealingLR, T_max=10),
            lambda opt, **kwargs: ChainedScheduler(
                schedulers=[ConstantLR(opt), ConstantLR(opt)], **kwargs
            ),
            lambda opt, **kwargs: SequentialLR(
                opt,
                schedulers=[ConstantLR(opt), ConstantLR(opt)],
                milestones=[2],
                **kwargs,
            ),
            ReduceLROnPlateau,
            partial(CyclicLR, base_lr=0.01, max_lr=0.1),
            partial(OneCycleLR, max_lr=0.01, total_steps=10, anneal_strategy="linear"),
            partial(CosineAnnealingWarmRestarts, T_0=20),
        ],
    )
    @parametrize("weights_only", [True, False])
    def test_lr_scheduler_state_dict_load(self, LRClass, weights_only):
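        # Scheduler state dicts should round-trip through torch.save/torch.load,
        # including with weights_only=True, and load cleanly into a fresh scheduler.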
        scheduler = LRClass(self.opt)
        state_dict = scheduler.state_dict()

        with tempfile.TemporaryFile() as f:
            torch.save(state_dict, f)
            f.seek(0)
            state_dict_loaded = torch.load(f, weights_only=weights_only)
            self.assertEqual(state_dict, state_dict_loaded)
            # Make sure state_dict can be loaded
            scheduler2 = LRClass(self.opt)
            scheduler2.load_state_dict(state_dict_loaded)
            self.assertEqual(scheduler2.state_dict(), state_dict)

    @parametrize(
        "LRClass",
        [
            partial(LambdaLR, lr_lambda=lambda e: e // 10),
            partial(MultiplicativeLR, lr_lambda=lambda e: 0.95),
            partial(StepLR, step_size=30),
            partial(MultiStepLR, milestones=[30, 80]),
            ConstantLR,
            LinearLR,
            partial(ExponentialLR, gamma=0.9),
            PolynomialLR,
            partial(CosineAnnealingLR, T_max=10),
            partial(CosineAnnealingWarmRestarts, T_0=20),
        ],
    )
    def test_constant_initial_lr(self, LRClass):
        # Test that initial_lr and base_lrs are snapshots of the lr at construction:
        # mutating the tensor lr in place afterwards must not change them.
        lr = torch.as_tensor(0.1)
        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
        sch = LRClass(opt)

        ori_param_groups = copy.deepcopy(opt.param_groups)

        for i in range(2):
            opt.step()
            sch.step(i)
            lr.multiply_(0.1)
            for group, ori_group in zip(opt.param_groups, ori_param_groups):
                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
                self.assertEqual(sch.base_lrs, [0.1])

    def test_constant_initial_params_cyclelr(self):
        # Test that the initial lr, max_lr, and momentum bounds are snapshots: mutating
        # the tensors in place after construction must not change them.
        lr = torch.as_tensor(0.1)
        max_lr = torch.as_tensor(0.2)
        base_momentum = torch.as_tensor(0.8)
        max_momentum = torch.as_tensor(0.9)
        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
        sch = CyclicLR(
            opt,
            base_lr=lr,
            max_lr=max_lr,
            base_momentum=base_momentum,
            max_momentum=max_momentum,
        )
        ori_param_groups = copy.deepcopy(opt.param_groups)

        for i in range(2):
            lr.multiply_(0.5)
            max_lr.multiply_(0.5)
            base_momentum.multiply_(0.5)
            max_momentum.multiply_(0.5)
            opt.step()
            sch.step(i)
            for group, ori_group in zip(opt.param_groups, ori_param_groups):
                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
                self.assertEqual(group["max_momentum"], ori_group["max_momentum"])
                self.assertEqual(group["base_momentum"], ori_group["base_momentum"])
                self.assertEqual(sch.base_lrs, [0.1])
                self.assertEqual(sch.max_lrs, [0.2])
                self.assertEqual(group["max_momentum"], 0.9)
                self.assertEqual(group["base_momentum"], 0.8)

    def test_constant_initial_params_onecyclelr(self):
        # Test that the initial lr, max_lr/min_lr, and momentum bounds are snapshots:
        # mutating the tensors in place after construction must not change them.
        lr = torch.as_tensor(0.1)
        base_momentum = torch.as_tensor(0.85)
        max_momentum = torch.as_tensor(0.95)
        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
        sch = OneCycleLR(
            opt,
            max_lr=lr,
            total_steps=10,
            base_momentum=base_momentum,
            max_momentum=max_momentum,
        )
        ori_param_groups = copy.deepcopy(opt.param_groups)

        for i in range(2):
            lr.multiply_(0.5)
            base_momentum.multiply_(0.5)
            max_momentum.multiply_(0.5)
            opt.step()
            sch.step(i)

            for group, ori_group in zip(opt.param_groups, ori_param_groups):
                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
                self.assertEqual(group["max_lr"], ori_group["max_lr"])
                self.assertEqual(group["min_lr"], ori_group["min_lr"])
                self.assertEqual(group["max_momentum"], ori_group["max_momentum"])
                self.assertEqual(group["base_momentum"], ori_group["base_momentum"])
                self.assertEqual(group["max_momentum"], 0.95)
                self.assertEqual(group["base_momentum"], 0.85)

    def test_constant_initial_params_swalr(self):
        # Test that the initial lr and swa_lr are snapshots: mutating the tensors in
        # place after construction must not change them.
        lr = torch.as_tensor(0.1)
        swa_lr = torch.as_tensor(0.05)
        opt = SGD([torch.nn.Parameter(torch.randn(1))], lr=lr)
        sch = SWALR(opt, swa_lr=swa_lr)
        ori_param_groups = copy.deepcopy(opt.param_groups)

        for i in range(2):
            lr.multiply_(0.5)
            swa_lr.multiply_(0.5)
            opt.step()
            sch.step()
            for group, ori_group in zip(opt.param_groups, ori_param_groups):
                self.assertEqual(group["initial_lr"], ori_group["initial_lr"])
                self.assertEqual(group["swa_lr"], ori_group["swa_lr"])
                self.assertEqual(group["swa_lr"], 0.05)
                self.assertEqual(sch.base_lrs, [0.1])


instantiate_parametrized_tests(TestLRScheduler)


if __name__ == "__main__":
    print("These tests should be run through test/test_optim.py instead")
