# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

# Helper functions for transforming the model so it can run SpinQuant.
# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant.


import torch

import torch.nn.functional as F

from executorch.examples.models.llama.llama_transformer import FeedForward
from torch import nn


def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module):
    """
    SpinQuant needs two Hadmard matrixes: R3 and R4. Here we are only injecting R4 in the feed forward layer.
    R3 needs to be injected as well when KV cache quantization is enabled.
    """
    try:
        from fast_hadamard_transform import hadamard_transform
    except ImportError:
        raise ImportError(
            "Please install fast-hadamard-transform: pip install fast-hadamard-transform"
        )

    class FeedForwardCudaCustom(nn.Module):
        def __init__(self, w1, w2, w3):
            super().__init__()
            self.w1 = w1
            self.w2 = w2
            self.w3 = w3

        def forward(self, x):
            w = F.silu(self.w1(x)) * self.w3(x)
            n = w.shape[-1]
            # R4: apply the Hadamard transform and divide by sqrt(n) so the
            # rotation is orthogonal (norm-preserving) before feeding w2.
            return self.w2(hadamard_transform(w.contiguous()) / torch.tensor(n).sqrt())

    for name, child in module.named_children():
        if isinstance(child, FeedForward):
            setattr(module, name, FeedForwardCudaCustom(child.w1, child.w2, child.w3))
        else:
            _inject_fast_hadamard_transform_cuda_for_spin_quant(child)


def inject_fast_hadamard_transform_cuda_for_spin_quant(
    module: torch.nn.Module,
) -> torch.nn.Module:
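    """
    Replace every FeedForward submodule of `module` in place with a variant that
    applies the R4 rotation via the CUDA `fast_hadamard_transform` kernel before
    w2, and return the transformed module.

    Usage sketch (assumes `model` is a llama Transformer and that a CUDA device
    plus the fast-hadamard-transform package are available):

        model = inject_fast_hadamard_transform_cuda_for_spin_quant(model)
    """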
    _inject_fast_hadamard_transform_cuda_for_spin_quant(module)
    return module
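

# Illustrative reference only (not used by the injection functions in this
# file): a pure-PyTorch sketch of the normalized Hadamard transform that the R4
# rotation relies on. A Sylvester Hadamard matrix H_n satisfies
# H_n @ H_n.T == n * I, so hadamard_transform(w) / sqrt(n) is an orthogonal
# rotation of the last dimension. This sketch assumes the last dimension is a
# power of two and is meant to mirror
# `hadamard_transform(w.contiguous()) / sqrt(n)` above; the helper name is ours.
def _reference_normalized_hadamard(x: torch.Tensor) -> torch.Tensor:
    n = x.shape[-1]
    assert n > 0 and n & (n - 1) == 0, "last dim must be a power of two"
    out = x
    h = 1
    # Iterative fast Walsh-Hadamard transform over the last dimension.
    while h < n:
        out = out.reshape(*x.shape[:-1], n // (2 * h), 2, h)
        a, b = out[..., 0, :], out[..., 1, :]
        out = torch.stack((a + b, a - b), dim=-2)
        h *= 2
    return out.reshape(x.shape) / (n**0.5)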


def _inject_fast_hadamard_transform_native_for_spin_quant(module: torch.nn.Module):
    """
    SpinQuant needs two Hadmard matrixes: R3 and R4. Here we are only injecting R4 in the feed forward layer.
    R3 needs to be injected as well when KV cache quantization is enabled.
    """

    class FeedForwardNativeCustom(nn.Module):
        def __init__(self, w1, w2, w3):
            super().__init__()
            self.w1 = w1
            self.w2 = w2
            self.w3 = w3

        def forward(self, x):
            # R4: apply the Hadamard transform via the ExecuTorch llama custom
            # op to the gated activation before feeding w2.
            return self.w2(
                torch.ops.llama.fast_hadamard_transform(F.silu(self.w1(x)) * self.w3(x))
            )

    for name, child in module.named_children():
        if isinstance(child, FeedForward):
            setattr(module, name, FeedForwardNativeCustom(child.w1, child.w2, child.w3))
        else:
            _inject_fast_hadamard_transform_native_for_spin_quant(child)


def inject_fast_hadamard_transform_native_for_spin_quant(
    module: torch.nn.Module,
) -> torch.nn.Module:
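    """
    Replace every FeedForward submodule of `module` in place with a variant that
    applies the R4 rotation via the `torch.ops.llama.fast_hadamard_transform`
    custom op before w2, and return the transformed module. The custom op must
    be registered before tracing/export for this variant to run.

    Usage sketch (assumes `model` is a llama Transformer):

        model = inject_fast_hadamard_transform_native_for_spin_quant(model)
    """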
    _inject_fast_hadamard_transform_native_for_spin_quant(module)
    return module
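

# Minimal self-check sketch (illustrative, CPU-only): exercises the
# `_reference_normalized_hadamard` helper defined above to show that the
# normalized Hadamard rotation used for R4 preserves L2 norms. Not part of the
# library API.
if __name__ == "__main__":
    sample = torch.randn(2, 8)
    rotated = _reference_normalized_hadamard(sample)
    torch.testing.assert_close(sample.norm(dim=-1), rotated.norm(dim=-1))
    print("R4 reference rotation preserves norms")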
