# Owner(s): ["oncall: distributed"]

import unittest
from typing import List, Optional, Tuple

import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.optim import Adam, AdamW, SGD
from torch.testing._internal.common_utils import run_tests, TestCase


class MyModule(torch.nn.Module):
    """Two bias-free 3x3 linear layers with a ReLU in between.

    The global torch RNG is reseeded with 0 right before the layers are built,
    so every instance starts from the same weights. Note that this reseeding is
    a global side effect of construction.
    """

    def __init__(self) -> None:
        super().__init__()
        # Reseed first so that lin1 and then lin2 draw identical initial weights
        # for every instance (creation order matters for the RNG stream).
        torch.manual_seed(0)
        self.lin1 = nn.Linear(3, 3, bias=False)
        self.lin2 = nn.Linear(3, 3, bias=False)

    def forward(self, t1):
        """Return ``lin2(relu(lin1(t1)))``."""
        hidden = F.relu(self.lin1(t1))
        return self.lin2(hidden)


# dummy class to showcase custom optimizer registration with functional wrapper
class MyDummyFnOptimizer:
    """Stand-in functional optimizer used to exercise custom registration.

    It validates its hyperparameters and stores them in ``self.defaults``, but
    its ``step`` and ``step_param`` methods deliberately raise ``RuntimeError``:
    only registration is being demonstrated, not an actual update rule.

    Args:
        params: Parameters to optimize; may be empty only when
            ``_allow_empty_param_list`` is True.
        lr: Learning rate, must be >= 0.
        betas: ``(beta1, beta2)`` coefficients, each in ``[0, 1)``.
        eps: Epsilon value, must be >= 0.
        weight_decay: Weight decay, must be >= 0.
        _allow_empty_param_list: Permit an empty ``params`` list.

    Raises:
        ValueError: If a hyperparameter is out of range, or ``params`` is empty
            while ``_allow_empty_param_list`` is False.
    """

    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        _allow_empty_param_list: bool = False,
    ):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        # Inclusive bound: the default of 0.0 must be accepted; a strict `<`
        # here made the class impossible to construct with its own defaults.
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        self.defaults = {
            "lr": lr,
            "eps": eps,
            "beta1": betas[0],
            "beta2": betas[1],
            "weight_decay": weight_decay,
        }

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

    def step_param(self, param: Tensor, grad: Optional[Tensor]):
        """Per-parameter step; not implemented, always raises ``RuntimeError``."""
        # call the custom optimizer step_param implementation
        with torch.no_grad():
            raise RuntimeError(
                "MyDummyFnOptimizer does not support step_param() as of now"
            )

    def step(self, gradients: List[Optional[Tensor]]):
        """Whole-model step; not implemented, always raises ``RuntimeError``."""
        # call the custom optimizer step implementation
        with torch.no_grad():
            raise RuntimeError("MyDummyFnOptimizer does not support step() as of now")


if torch.distributed.is_available():
    from torch.distributed.optim.utils import (
        functional_optim_map,
        register_functional_optim,
    )


@unittest.skipIf(
    not torch.distributed.is_available(), "These are testing distributed functions"
)
class TestFunctionalOptimParity(TestCase):
    """Parity tests between ``torch.optim`` optimizers and their functional forms.

    Each parity test builds two ``MyModule`` instances (which start from equal
    weights), trains one with a regular optimizer and the other by calling
    ``step_param`` on the functional optimizer found in ``functional_optim_map``,
    and checks that the parameters stay equal while actually changing.
    """

    def _validate_parameters(self, params_1, params_2):
        """Assert that two parameter iterables hold pairwise-equal tensors.

        Both iterables are materialized first so a length mismatch fails the
        test instead of being silently truncated by ``zip``.
        """
        params_1 = list(params_1)
        params_2 = list(params_2)
        self.assertEqual(len(params_1), len(params_2))
        for p1, p2 in zip(params_1, params_2):
            self.assertEqual(p1, p2)

    # Dynamo fails at compiling this helper for python 3.8/3.11.
    # Since it passes while compiling the actual code under test,
    # we disable dynamo for this helper only.
    @torch._disable_dynamo(recursive=False)
    def _test_functional_optim_parity(self, optim_cls, *args, **kwargs):
        """Train two identical modules side by side and assert they stay in sync.

        Args:
            optim_cls: A ``torch.optim`` optimizer class used as a key into
                ``functional_optim_map``.
            *args, **kwargs: Hyperparameters forwarded unchanged to both
                ``optim_cls`` and its functional counterpart.

        Raises:
            ValueError: If no functional optimizer is registered for
                ``optim_cls``, or it does not implement ``step_param``.
        """
        module_optim = MyModule()
        module_functional = MyModule()
        optim = optim_cls(module_optim.parameters(), *args, **kwargs)
        functional_optim_cls = functional_optim_map.get(optim_cls, None)
        if not functional_optim_cls:
            raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
        # The functional optimizer is built with no parameters of its own; it is
        # driven one parameter at a time through step_param below.
        optim_functional = functional_optim_cls(
            [], *args, **kwargs, _allow_empty_param_list=True
        )
        if not hasattr(optim_functional, "step_param"):
            raise ValueError(
                f"Functional optimizer class {optim_functional} must implement step_param method."
            )

        # Initial weights should match.
        self._validate_parameters(
            module_optim.parameters(), module_functional.parameters()
        )
        # Snapshot the starting weights to verify the optimizers modify them.
        old_module_optim_params = [
            param.clone().detach() for param in module_optim.parameters()
        ]
        old_module_functional_params = [
            param.clone().detach() for param in module_functional.parameters()
        ]

        t1 = torch.randn(3, 3)
        for _ in range(10):
            module_optim.zero_grad()
            module_functional.zero_grad()
            # Forward + Backward on the same input for both modules.
            optim_out = module_optim(t1).sum()
            functional_out = module_functional(t1).sum()
            optim_out.backward()
            functional_out.backward()
            # The regular optimizer steps all of its parameters at once...
            optim.step()
            # ...while the functional one is stepped parameter by parameter.
            for param in module_functional.parameters():
                optim_functional.step_param(param, param.grad)

            # Both modules must still agree after every step...
            self._validate_parameters(
                module_optim.parameters(), module_functional.parameters()
            )
            # ...and the weights must have moved away from their initial values.
            for old_optim, old_functional, optim_param, functional_param in zip(
                old_module_optim_params,
                old_module_functional_params,
                module_optim.parameters(),
                module_functional.parameters(),
            ):
                self.assertNotEqual(old_optim, optim_param)
                self.assertNotEqual(old_functional, functional_param)

    def _test_functional_optim_registration(self):
        """Register ``MyDummyFnOptimizer`` and check the map now holds it."""
        fn_map_key = "MyDummyFnOptimizer"
        fn_optim = MyDummyFnOptimizer
        register_functional_optim(fn_map_key, fn_optim)
        functional_optim_cls = functional_optim_map.get(fn_map_key, None)
        if not functional_optim_cls:
            raise ValueError(f"Functional optimizer not registered for {fn_map_key}")
        # The entry must be the very class that was registered, not just any value.
        self.assertIs(functional_optim_cls, fn_optim)

    def test_functional_optim_registration(self):
        self._test_functional_optim_registration()

    def test_functional_optim_parity_sgd(self):
        self._test_functional_optim_parity(SGD, 1e-2, momentum=0.9, weight_decay=0.01)

    def test_functional_optim_parity_adam(self):
        self._test_functional_optim_parity(Adam, 1e-2, betas=(0.9, 0.999), eps=1e-6)

    def test_functional_optim_parity_adam_w(self):
        self._test_functional_optim_parity(AdamW, 1e-2, betas=(0.9, 0.999), eps=1e-6)


if __name__ == "__main__":
    # Run directly as a script: hand control to PyTorch's common test runner.
    run_tests()
