# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import Tuple

import torch
from torch.ao.quantization.fake_quantize import _is_symmetric_quant, FakeQuantize
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.utils import is_per_tensor


class AdaroundFakeQuantizer(FakeQuantize):
    """
    A FakeQuantize module implementing adaptive rounding (AdaRound).
    AdaRound learns whether to round each weight element up or down rather than
    rounding to nearest; it is derived from the paper https://arxiv.org/pdf/2004.10568.pdf
    For HTP compatibility, we target symmetric quantization only.
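
    Example (illustrative sketch; the weight tensor and the dtype choice are
    assumptions, not prescribed by this module):

        fq = AdaroundFakeQuantizer(dtype=torch.qint8)
        weight = torch.randn(16, 64)
        # The first call observes the weight, derives scale/zero_point, and initializes V
        soft_weight = fq(weight)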
    """

    scale: torch.Tensor
    zero_point: torch.Tensor
    V: torch.nn.Parameter

    def __init__(
        self,
        observer=MinMaxObserver,
        qscheme=torch.per_tensor_symmetric,  # not used, but needed for fakequant
        quant_min: int = -128,
        quant_max: int = 127,
        ch_axis: int = 0,
        # pyre-fixme[2]: Parameter must be annotated.
        **observer_kwargs,
    ) -> None:
        super().__init__(
            observer=observer,
            qscheme=qscheme,
            quant_min=quant_min,
            quant_max=quant_max,
            is_dynamic=False,
            **observer_kwargs,
        )
        # Sanity-check the quantization range; quant_min/quant_max were already
        # forwarded to the FakeQuantize base class above
        if quant_min is not None and quant_max is not None:
            assert (
                quant_min <= quant_max
            ), "quant_min must be less than or equal to quant_max"
        # pyre-fixme[4]: Attribute must be annotated.
        self.qscheme = qscheme
        self.is_per_tensor: bool = is_per_tensor(qscheme)
        self.is_symmetric: bool = _is_symmetric_quant(qscheme)
        assert self.is_symmetric, "Only symmetric quantization is supported"
        self.ch_axis: int = ch_axis

        self.scale = torch.tensor([], requires_grad=False)
        self.zero_point = torch.tensor([], requires_grad=False)
        self.V = torch.nn.Parameter(torch.tensor([]), requires_grad=True)
        # Fixed stretch parameters (zeta, gamma) for the rectified sigmoid h(V)
        self.zeta: torch.Tensor = torch.tensor(1.1, requires_grad=False)
        self.gamma: torch.Tensor = torch.tensor(-0.1, requires_grad=False)
        self.sigmoid = torch.nn.Sigmoid()
        self.use_soft_rounding = True

    @torch.jit.export
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.scale, self.zero_point

    @torch.jit.export
    def extra_repr(self) -> str:
        return (
            f"fake_quant_enabled={self.fake_quant_enabled}, observer_enabled={self.observer_enabled}, "
            f"quant_min={self.activation_post_process.quant_min}, quant_max={self.activation_post_process.quant_max}, "
            f"dtype={self.dtype}, qscheme={self.qscheme}, ch_axis={self.ch_axis}, "
            f"scale={self.scale}, zero_point={self.zero_point}, (self.V >= 0).int().sum()={(self.V >= 0).int().sum()}"
        )

    def enable_weight_fake_quant(self) -> None:
        self.fake_quant_enabled[0] = 1

    def get_rectified_sigmoid_func(self) -> torch.Tensor:
        """Return the rounding variable h(V) = clamp(sigmoid(V) * (zeta - gamma) + gamma, 0, 1)."""
        if self.use_soft_rounding:
            return torch.clamp(
                self.sigmoid(self.V) * (self.zeta - self.gamma) + self.gamma,
                min=0,
                max=1,
            )
        else:
            # Hard rounding: return the binary (0/1) rounding decision
            return (self.V >= 0).int()

    @torch.jit.ignore
    def update_scale(
        self, X: torch.Tensor, _scale: torch.Tensor, _zero_point: torch.Tensor
    ) -> None:
        if self.scale.numel() == 0:
            # First observation: adopt the observer's qparams as-is
            self.scale.data = _scale.to(X.device)
            self.zero_point = _zero_point.to(X.device)
        else:
            self.scale.data = _scale
            if not self.is_symmetric:
                self.zero_point = _zero_point
            else:
                # Symmetric quantization pins the zero point to zero; _zero_point
                # already carries the broadcast shape prepared in forward(), so no
                # further reshaping is needed here
                self.zero_point = torch.zeros_like(_zero_point)
        # Initialize V so that h(V) equals the fractional residual of X / scale,
        # i.e. the soft-quantized weight initially reproduces X (up to clamping)
        X_q = X / self.scale
        X_q_floor = torch.floor(X_q)
        residual = X_q - X_q_floor  # [0, 1)
        assert torch.all(
            torch.ge(residual, 0)
        ), "residual should be non-negative [0, 1)"
        # Inverse of the rectified sigmoid: V = logit((residual - gamma) / (zeta - gamma))
        V_init = -torch.log((self.zeta - self.gamma) / (residual - self.gamma) - 1)
        self.V.data = V_init

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        if self.observer_enabled[0] == 1:
            X_detached = X.detach()
            self.activation_post_process(X_detached)
            _scale, _zero_point = self.activation_post_process.calculate_qparams()
            _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(
                self.zero_point.device
            )
            if not self.is_per_tensor:
                # Reshape per-channel qparams so they broadcast against X
                for i in range(X.dim()):
                    if i == self.ch_axis:
                        continue
                    _scale = _scale.unsqueeze(i)
                    _zero_point = _zero_point.unsqueeze(i)
            self.update_scale(X_detached, _scale, _zero_point)

        if self.fake_quant_enabled[0] == 1:
            # Perform soft quantization; see Eq. (23) in the AdaRound paper:
            #   X_dq = s * (clamp(floor(X / s) + h(V) + z, quant_min, quant_max) - z)
            h_v = self.get_rectified_sigmoid_func()
            X_q = X / self.scale
            # torch.floor has a zero gradient, which is fine here: AdaRound trains
            # only self.V (not the weight), and the gradient reaches self.V through h_v
            X_q_floor = torch.floor(X_q) + self.zero_point
            X_q_dq = (
                torch.clamp(X_q_floor + h_v, min=self.quant_min, max=self.quant_max)
                - self.zero_point
            ) * self.scale
            return X_q_dq
        else:
            return X
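

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the public API): how the rounding variable V
# is typically optimized once the fake quantizer is attached to a weight. The
# linear-layer setup, the fixed beta, and the regularization weight below are
# simplifying assumptions for demonstration; implementations commonly anneal
# beta during optimization.
def _adaround_optimize_sketch(
    fq: AdaroundFakeQuantizer,
    weight: torch.Tensor,
    calib_input: torch.Tensor,
    num_steps: int = 100,
    beta: float = 20.0,
    reg_weight: float = 0.01,
) -> None:
    # One observed pass derives scale/zero_point and initializes V, then the
    # observer is frozen so later passes do not re-initialize V
    fq(weight.detach())
    fq.disable_observer()
    fp_out = calib_input @ weight.t()
    optimizer = torch.optim.Adam([fq.V], lr=1e-2)
    for _ in range(num_steps):
        # Soft-quantized weight, differentiable with respect to fq.V only
        soft_weight = fq(weight)
        recon_loss = torch.nn.functional.mse_loss(calib_input @ soft_weight.t(), fp_out)
        # Rounding regularizer from the AdaRound paper: pushes h(V) toward 0 or 1
        h_v = fq.get_rectified_sigmoid_func()
        round_loss = torch.sum(1 - torch.abs(2 * h_v - 1) ** beta)
        loss = recon_loss + reg_weight * round_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()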
