# Owner(s): ["module: optimizer"]

import itertools
import pickle

import torch
from torch.optim.swa_utils import (
    AveragedModel,
    get_ema_multi_avg_fn,
    get_swa_multi_avg_fn,
    update_bn,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    load_tests,
    parametrize,
    TestCase,
)


# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests


class TestSWAUtils(TestCase):
    class SWATestDNN(torch.nn.Module):
        def __init__(self, input_features):
            super().__init__()
            self.n_features = 100
            self.fc1 = torch.nn.Linear(input_features, self.n_features)
            self.bn = torch.nn.BatchNorm1d(self.n_features)

        def compute_preactivation(self, x):
            return self.fc1(x)

        def forward(self, x):
            x = self.fc1(x)
            x = self.bn(x)
            return x

    class SWATestCNN(torch.nn.Module):
        def __init__(self, input_channels):
            super().__init__()
            self.n_features = 10
            self.conv1 = torch.nn.Conv2d(
                input_channels, self.n_features, kernel_size=3, padding=1
            )
            self.bn = torch.nn.BatchNorm2d(self.n_features, momentum=0.3)

        def compute_preactivation(self, x):
            return self.conv1(x)

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn(x)
            return x

    def _test_averaged_model(self, net_device, swa_device, ema):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Conv2d(5, 2, kernel_size=3),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 5),
            torch.nn.ReLU(),
            torch.nn.Linear(5, 10),
        ).to(net_device)

        averaged_params, averaged_dnn = self._run_averaged_steps(dnn, swa_device, ema)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_swa.device == swa_device)
            self.assertTrue(p_avg.device == net_device)
        self.assertTrue(averaged_dnn.n_averaged.device == swa_device)

    def _run_averaged_steps(self, dnn, swa_device, ema):
        ema_decay = 0.999
        if ema:
            averaged_dnn = AveragedModel(
                dnn, device=swa_device, multi_avg_fn=get_ema_multi_avg_fn(ema_decay)
            )
        else:
            averaged_dnn = AveragedModel(
                dnn, device=swa_device, multi_avg_fn=get_swa_multi_avg_fn()
            )

        averaged_params = [torch.zeros_like(param) for param in dnn.parameters()]

        n_updates = 10
        for i in range(n_updates):
            for p, p_avg in zip(dnn.parameters(), averaged_params):
                p.detach().add_(torch.randn_like(p))
                if ema:
                    p_avg += (
                        p.detach()
                        * ema_decay ** (n_updates - i - 1)
                        * ((1 - ema_decay) if i > 0 else 1.0)
                    )
                else:
                    p_avg += p.detach() / n_updates
            averaged_dnn.update_parameters(dnn)

        return averaged_params, averaged_dnn

    @parametrize("ema", [True, False])
    def test_averaged_model_all_devices(self, ema):
        cpu = torch.device("cpu")
        self._test_averaged_model(cpu, cpu, ema)
        if torch.cuda.is_available():
            cuda = torch.device(0)
            self._test_averaged_model(cuda, cpu, ema)
            self._test_averaged_model(cpu, cuda, ema)
            self._test_averaged_model(cuda, cuda, ema)

    @parametrize("ema", [True, False])
    def test_averaged_model_mixed_device(self, ema):
        if not torch.cuda.is_available():
            return
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
        )
        dnn[0].cuda()
        dnn[1].cpu()

        averaged_params, averaged_dnn = self._run_averaged_steps(dnn, None, ema)

        for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
            self.assertEqual(p_avg, p_swa)
            # Check that AveragedModel is on the correct device
            self.assertTrue(p_avg.device == p_swa.device)

    def test_averaged_model_state_dict(self):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3), torch.nn.Linear(5, 10)
        )
        averaged_dnn = AveragedModel(dnn)
        averaged_dnn2 = AveragedModel(dnn)
        n_updates = 10
        for i in range(n_updates):
            for p in dnn.parameters():
                p.detach().add_(torch.randn_like(p))
            averaged_dnn.update_parameters(dnn)
        averaged_dnn2.load_state_dict(averaged_dnn.state_dict())
        for p_swa, p_swa2 in zip(averaged_dnn.parameters(), averaged_dnn2.parameters()):
            self.assertEqual(p_swa, p_swa2)
        self.assertTrue(averaged_dnn.n_averaged == averaged_dnn2.n_averaged)

    def test_averaged_model_default_avg_fn_picklable(self):
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5),
            torch.nn.Linear(5, 5),
        )
        averaged_dnn = AveragedModel(dnn)
        pickle.dumps(averaged_dnn)

    @parametrize("use_multi_avg_fn", [True, False])
    @parametrize("use_buffers", [True, False])
    def test_averaged_model_exponential(self, use_multi_avg_fn, use_buffers):
        # Test AveragedModel with EMA as avg_fn and use_buffers as True.
        dnn = torch.nn.Sequential(
            torch.nn.Conv2d(1, 5, kernel_size=3),
            torch.nn.BatchNorm2d(5, momentum=0.3),
            torch.nn.Linear(5, 10),
        )
        decay = 0.9

        if use_multi_avg_fn:
            averaged_dnn = AveragedModel(
                dnn, multi_avg_fn=get_ema_multi_avg_fn(decay), use_buffers=use_buffers
            )
        else:

            def avg_fn(p_avg, p, n_avg):
                return decay * p_avg + (1 - decay) * p

            averaged_dnn = AveragedModel(dnn, avg_fn=avg_fn, use_buffers=use_buffers)

        if use_buffers:
            dnn_params = list(itertools.chain(dnn.parameters(), dnn.buffers()))
        else:
            dnn_params = list(dnn.parameters())

        averaged_params = [
            torch.zeros_like(param)
            for param in dnn_params
            if param.size() != torch.Size([])
        ]

        n_updates = 10
        for i in range(n_updates):
            updated_averaged_params = []
            for p, p_avg in zip(dnn_params, averaged_params):
                if p.size() == torch.Size([]):
                    continue
                p.detach().add_(torch.randn_like(p))
                if i == 0:
                    updated_averaged_params.append(p.clone())
                else:
                    updated_averaged_params.append(
                        (p_avg * decay + p * (1 - decay)).clone()
                    )
            averaged_dnn.update_parameters(dnn)
            averaged_params = updated_averaged_params

        if use_buffers:
            for p_avg, p_swa in zip(
                averaged_params,
                itertools.chain(
                    averaged_dnn.module.parameters(), averaged_dnn.module.buffers()
                ),
            ):
                self.assertEqual(p_avg, p_swa)
        else:
            for p_avg, p_swa in zip(averaged_params, averaged_dnn.parameters()):
                self.assertEqual(p_avg, p_swa)
            for b_avg, b_swa in zip(dnn.buffers(), averaged_dnn.module.buffers()):
                self.assertEqual(b_avg, b_swa)

    def _test_update_bn(self, dnn, dl_x, dl_xy, cuda):
        preactivation_sum = torch.zeros(dnn.n_features)
        preactivation_squared_sum = torch.zeros(dnn.n_features)
        if cuda:
            preactivation_sum = preactivation_sum.cuda()
            preactivation_squared_sum = preactivation_squared_sum.cuda()
        total_num = 0
        for x in dl_x:
            x = x[0]
            if cuda:
                x = x.cuda()

            dnn.forward(x)
            preactivations = dnn.compute_preactivation(x)
            if len(preactivations.shape) == 4:
                preactivations = preactivations.transpose(1, 3)
            preactivations = preactivations.contiguous().view(-1, dnn.n_features)
            total_num += preactivations.shape[0]

            preactivation_sum += torch.sum(preactivations, dim=0)
            preactivation_squared_sum += torch.sum(preactivations**2, dim=0)

        preactivation_mean = preactivation_sum / total_num
        preactivation_var = preactivation_squared_sum / total_num
        preactivation_var = preactivation_var - preactivation_mean**2

        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)

        def _reset_bn(module):
            if issubclass(module.__class__, torch.nn.modules.batchnorm._BatchNorm):
                module.running_mean = torch.zeros_like(module.running_mean)
                module.running_var = torch.ones_like(module.running_var)

        # reset batch norm and run update_bn again
        dnn.apply(_reset_bn)
        update_bn(dl_xy, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)
        # using the dl_x loader instead of dl_xy
        dnn.apply(_reset_bn)
        update_bn(dl_x, dnn, device=x.device)
        self.assertEqual(preactivation_mean, dnn.bn.running_mean)
        self.assertEqual(preactivation_var, dnn.bn.running_var, atol=1e-1, rtol=0)

    def test_update_bn_dnn(self):
        # Test update_bn for a fully-connected network with BatchNorm1d
        objects, input_features = 100, 5
        x = torch.rand(objects, input_features)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
        dnn = self.SWATestDNN(input_features=input_features)
        dnn.train()
        self._test_update_bn(dnn, dl_x, dl_xy, False)
        if torch.cuda.is_available():
            dnn = self.SWATestDNN(input_features=input_features)
            dnn.train()
            self._test_update_bn(dnn.cuda(), dl_x, dl_xy, True)
        self.assertTrue(dnn.training)

    def test_update_bn_cnn(self):
        # Test update_bn for convolutional network and BatchNorm2d
        objects = 100
        input_channels = 3
        height, width = 5, 5
        x = torch.rand(objects, input_channels, height, width)
        y = torch.rand(objects)
        ds_x = torch.utils.data.TensorDataset(x)
        ds_xy = torch.utils.data.TensorDataset(x, y)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        dl_xy = torch.utils.data.DataLoader(ds_xy, batch_size=5, shuffle=True)
        cnn = self.SWATestCNN(input_channels=input_channels)
        cnn.train()
        self._test_update_bn(cnn, dl_x, dl_xy, False)
        if torch.cuda.is_available():
            cnn = self.SWATestCNN(input_channels=input_channels)
            cnn.train()
            self._test_update_bn(cnn.cuda(), dl_x, dl_xy, True)
        self.assertTrue(cnn.training)

    def test_bn_update_eval_momentum(self):
        # check that update_bn preserves eval mode
        objects = 100
        input_channels = 3
        height, width = 5, 5
        x = torch.rand(objects, input_channels, height, width)
        ds_x = torch.utils.data.TensorDataset(x)
        dl_x = torch.utils.data.DataLoader(ds_x, batch_size=5, shuffle=True)
        cnn = self.SWATestCNN(input_channels=input_channels)
        cnn.eval()
        update_bn(dl_x, cnn)
        self.assertFalse(cnn.training)

        # check that momentum is preserved
        self.assertEqual(cnn.bn.momentum, 0.3)


instantiate_parametrized_tests(TestSWAUtils)


if __name__ == "__main__":
    print("These tests should be run through test/test_optim.py instead")
