# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import unittest

import torch
from executorch.exir import EdgeCompileConfig, to_edge

from executorch.extension.llm.modules.attention import (
    MultiHeadAttention as ETMultiHeadAttention,
)
from executorch.runtime import Runtime
from torch._inductor.package import load_package, package_aoti
from torch.testing import assert_close
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention


class AttentionTest(unittest.TestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(0)
        # Constants
        self.embed_dim = 2048
        self.num_heads = 8
        self.num_kv_heads = 8
        self.head_dim = 64
        self.max_seq_len = 128
        self.rope_base = 500_000
        self.scale_factor = 32

        # Module dependency injections.
        self.q_proj = torch.nn.Linear(
            self.embed_dim, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = torch.nn.Linear(
            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
        )
        self.k_proj.weight.requires_grad = False
        self.v_proj = torch.nn.Linear(
            self.embed_dim, self.num_kv_heads * self.head_dim, bias=False
        )
        self.v_proj.weight.requires_grad = False
        self.output_proj = torch.nn.Linear(
            self.num_heads * self.head_dim, self.embed_dim, bias=False
        )
        self.pos_embeddings = Llama3ScaledRoPE(
            dim=self.head_dim,
            max_seq_len=self.max_seq_len,
            base=self.rope_base,
            scale_factor=self.scale_factor,
        )

        # Original TorchTune reference module to test accuracy against.
        self.tt_mha = TTMultiHeadAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            q_proj=self.q_proj,
            k_proj=self.k_proj,
            v_proj=self.v_proj,
            output_proj=self.output_proj,
            pos_embeddings=self.pos_embeddings,
            max_seq_len=self.max_seq_len,
        )

        # Source-transformed module that we are testing.
        self.et_mha = ETMultiHeadAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.head_dim,
            q_proj=self.q_proj,
            k_proj=self.k_proj,
            v_proj=self.v_proj,
            output_proj=self.output_proj,
            pos_embeddings=self.pos_embeddings,
            max_seq_len=self.max_seq_len,
        )
        self.et_mha.load_state_dict(self.tt_mha.state_dict())
        # Common inputs.
        seq_len = 10
        self.x = torch.randn(1, seq_len, self.embed_dim)
        self.input_pos = torch.arange(seq_len).unsqueeze(0)  # shape [1, seq_len]
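        # Dynamic-shape spec for the export tests below, one entry per input
        # (x, y, input_pos): batch and embedding dims stay static, while the
        # sequence dim may range from 1 to 100.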
        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
        self.dynamic_shapes = (
            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
            {0: torch.export.Dim.STATIC, 1: seq_len_dim},
        )
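        # Full lower-triangular causal mask; the torch.cond test slices rows out
        # of it for the positions being attended from.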
        self.causal_mask = torch.tril(
            torch.ones(
                size=(self.max_seq_len, self.max_seq_len),
                dtype=torch.bool,
            )
        )

    def test_attention_eager(self):
        et_res = self.et_mha(self.x, self.x)  # Self attention.
        tt_res = self.tt_mha(self.x, self.x)  # Self attention.

        assert_close(et_res, tt_res)

        # Test with KV cache enabled.
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)

        et_res = self.et_mha(self.x, self.x)  # Self attention.
        tt_res = self.tt_mha(self.x, self.x)  # Self attention.

        assert_close(et_res, tt_res)
        self.et_mha.reset_cache()
        self.tt_mha.reset_cache()

        et_res = self.et_mha(
            self.x, self.x, input_pos=self.input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, input_pos=self.input_pos
        )  # Self attention with input pos.

        assert_close(et_res, tt_res)

        # Test the KV cache read. Input pos can be [10, 11, ..., 19].
        next_input_pos = torch.arange(10, 20).unsqueeze(0)
        et_res = self.et_mha(
            self.x, self.x, input_pos=next_input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, input_pos=next_input_pos
        )  # Self attention with input pos.

        assert_close(et_res, tt_res)

    def test_attention_export(self):
        # Self attention with KV cache enabled.
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
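        # Export the ET module with the dynamic sequence-length spec, then compare
        # the exported graph's output against the eager torchtune reference.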
        with torch.no_grad():
            et_mha_ep = torch.export.export(
                self.et_mha,
                (self.x, self.x),
                kwargs={"input_pos": self.input_pos},
                dynamic_shapes=self.dynamic_shapes,
            )
        et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

        assert_close(et_res, tt_res)

    @unittest.skipIf(
        int(os.getenv("RUN_SKIPPED", "0")) < 1,
        reason="TODO(T207740932): test is flaky",
    )
    def test_attention_aoti(self):
        # Self attention with KV cache enabled.
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
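        # Compile the ET module ahead of time with AOTInductor; with
        # `aot_inductor.package=True` the output can be packaged into a .pt2 archive.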
        with torch.no_grad():
            so = torch._export.aot_compile(
                self.et_mha,
                args=(self.x, self.x),
                kwargs={"input_pos": self.input_pos},
                options={
                    "aot_inductor.package": True,
                    "reorder_for_peak_memory": False,
                },
                dynamic_shapes=self.dynamic_shapes,
            )
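        # Package the compiled artifacts into a .pt2 file, load it back, and check
        # the AOTI-compiled module against the eager torchtune reference.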
        with tempfile.TemporaryDirectory() as tempdir:
            path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
            mha_aoti = load_package(path)

            aoti_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
            tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
            assert_close(aoti_res, tt_res)

    def test_attention_executorch(self):
        # Self attention.
        # TODO: Fix kv cache
        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)

        with torch.no_grad():
            et_mha_ep = torch.export.export(
                self.et_mha,
                (self.x, self.x),
                kwargs={"input_pos": self.input_pos},
                dynamic_shapes=self.dynamic_shapes,
            )
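        # Lower the exported program to the Edge dialect and then to an ExecuTorch
        # program; the exception list lets aten._assert_async.msg pass the core-ATen
        # op check.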
        et_program = to_edge(
            et_mha_ep,
            compile_config=EdgeCompileConfig(
                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
            ),
        ).to_executorch()
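        # Execute the serialized program through the ExecuTorch Python runtime and
        # compare against the eager torchtune reference.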
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
        et_res = method.execute((self.x, self.x, self.input_pos))
        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

        assert_close(et_res[0], tt_res)

    def test_attention_torch_cond_eager(self):
        # Unlike the vanilla torchtune MHA, the ET module rewrites the `if` condition
        # with torch.cond, so we check that both implementations agree on either side
        # of that condition: the first run provides a real `y` (self.x), while the
        # second run passes a tensor full of NaNs, as shown in the sketch below.
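        # A rough sketch (not the module's actual code) of what that rewrite looks
        # like: the data-dependent `if y is None` branch becomes a traceable
        # torch.cond over a NaN sentinel, along the lines of
        #
        #     use_cache_only = torch.isnan(y).all()
        #     k, v = torch.cond(use_cache_only, cache_branch_fn, project_y_fn, (y,))
        #
        # where `cache_branch_fn` and `project_y_fn` are hypothetical branch names.
        # The NaN-filled `y` below plays the role that `None` plays for the eager
        # torchtune module.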
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

        # mask
        mask = self.causal_mask[self.input_pos, :]
        # First run
        et_res = self.et_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.

        assert_close(et_res, tt_res)

        # Second run: test the KV cache read. Input pos is [10, 11, ..., 19].
        next_input_pos = torch.arange(10, 20).unsqueeze(0)

        empty_y = torch.full_like(self.x, torch.nan)  # NaN sentinel standing in for None.
        mask = self.causal_mask[next_input_pos, :]
        et_res = self.et_mha(
            self.x, empty_y, mask=mask, input_pos=next_input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, None, mask=mask, input_pos=next_input_pos
        )  # Self attention with input pos.

        assert_close(et_res, tt_res)
