"""Common backbone across multiple models"""

import math

import numpy as np
import torch
from models.llm_models.configuration_base import BaseConfig
from models.llm_models.modeling_base import BaseModelChunk
from torch import nn
from torch.export import Dim

# Seed the global torch and numpy RNGs once, at import time. In this file the
# torch seed is what makes randomly initialised modules and the ``torch.rand``
# placeholder weights (``ModelChunk.load_weights(None, ...)``) repeatable for a
# fixed sequence of calls; numpy is only seeded here, not otherwise used.
_GLOBAL_SEED = 42
torch.manual_seed(_GLOBAL_SEED)
np.random.seed(_GLOBAL_SEED)


# flake8: noqa: C901


class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learned per-channel scale.

    Computes ``weight * x / sqrt(mean(x ** 2, dim=-1) + eps)``; unlike
    LayerNorm there is no mean subtraction and no bias.

    Args:
        hidden_size: Size of the last (normalised) dimension.
        eps: Added to the mean square before the reciprocal square root.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        mean_square = hidden_states.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalised = hidden_states * inv_rms
        return self.weight * normalised


class Gelu(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``.

    Parameter-free; applied element-wise.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(1.702 * x)
        return x * gate


class MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    SiLU is written out as ``g * sigmoid(g)``. All three projections carry a
    bias (``nn.Linear`` default).
    """

    def __init__(self, config: BaseConfig):
        super().__init__()
        d_model = config.hidden_size
        d_ff = config.intermediate_size

        # Creation order (gate, down, up) is kept as-is.
        self.gate_proj = nn.Linear(d_model, d_ff)
        self.down_proj = nn.Linear(d_ff, d_model)
        self.up_proj = nn.Linear(d_model, d_ff)

    def forward(self, x):
        gated = self.gate_proj(x)
        expanded = self.up_proj(x)
        mixed = gated * torch.sigmoid(gated) * expanded
        return self.down_proj(mixed)


class Attention(nn.Module):
    """Self-attention with grouped KV heads over a fixed-length rolling KV cache.

    Queries use ``num_attention_heads`` heads and keys/values use
    ``num_key_value_heads`` heads; when there are fewer KV heads they are tiled
    by :meth:`repeat_kv` before scoring. Each call concatenates the new
    keys/values after the cache and returns only the last ``c`` positions, so
    the cache length passed in equals the cache length returned.
    """

    def __init__(self, config: BaseConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.attn_scale = math.sqrt(self.head_dim)

        kv_width = self.num_key_value_heads * self.head_dim
        if config.combine_qkv:
            # One fused projection; forward() slices its output as
            # [query (hidden_size) | key (kv_width) | value (kv_width)].
            self.qkv_proj = nn.Linear(
                self.hidden_size,
                (2 * kv_width) + self.hidden_size,
            )
        else:
            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim)
            self.k_proj = nn.Linear(self.hidden_size, kv_width)
            self.v_proj = nn.Linear(self.hidden_size, kv_width)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size)

    def apply_rotary_pos_emb_mtk(self, q, k, cos, sin):
        """Apply rotary position embedding with the rotate-half convention.

        For ``x`` in ``(q, k)``: ``x * cos + cat(-x[..., d/2:], x[..., :d/2]) * sin``.
        """

        def rotate_half(x):
            mid = x.shape[-1] // 2
            front = x[..., :mid]
            back = x[..., mid:]
            return torch.cat((-back, front), dim=-1)

        q_embed = q * cos + rotate_half(q) * sin
        k_embed = k * cos + rotate_half(k) * sin
        return q_embed, k_embed

    def repeat_kv(self, hidden_states, batch, q_len, n_rep):
        """Tile KV heads ``n_rep`` times so they line up with the query heads.

        For a tensor input (b, num_kv_heads, q_len, head_dim), ``repeat`` along
        the sequence axis followed by ``view`` to ``num_heads`` heads pairs
        query head ``h`` with KV head ``h // n_rep``.

        For a list input, each element is tiled and viewed as
        (batch, 1, q_len, head_dim). NOTE(review): that view needs each element
        to hold ``batch * q_len * head_dim / n_rep`` values — confirm the
        caller's shapes; this file only passes tensors.
        """
        if not isinstance(hidden_states, list):
            tiled = hidden_states.repeat(1, 1, n_rep, 1)
            return tiled.view(batch, self.num_heads, q_len, self.head_dim)
        return [
            hs.repeat(1, 1, n_rep, 1).view(batch, 1, q_len, self.head_dim)
            for hs in hidden_states
        ]

    def _split_heads(self, flat, bsz, q_len, n_heads):
        """Reshape (b, t, n_heads * head_dim) to (b, n_heads, t, head_dim)."""
        return flat.view(bsz, q_len, n_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states,  # (b, t, hidden_size)
        mask,  # (b, 1, t, c+t)
        pos_emb,  # (b, 2, t, head dim)
        past_key,  # (b, num kv heads, c, head dim)
        past_value,  # (b, num kv heads, c, head dim)
    ):
        """Attend over ``cache + new`` keys/values and roll the cache forward.

        ``mask`` is added to the scaled scores before the softmax. ``pos_emb``
        is split into ``(cos, sin)`` and used only when
        ``config.position_embedding == "rope"``.

        Returns:
            ``(attn_output, new_key_cache, new_value_cache)``: ``attn_output``
            is (b, t, hidden_size); each cache is the last ``c`` positions of
            the concatenated (c + t)-long key/value sequence.
        """
        bsz, q_len, _ = hidden_states.size()
        c_len = past_key.size()[2]

        if self.config.combine_qkv:
            fused = self.qkv_proj(hidden_states)
            q_stop = self.config.hidden_size
            k_stop = q_stop + self.num_key_value_heads * self.head_dim
            q_flat = fused[:, :, :q_stop]
            k_flat = fused[:, :, q_stop:k_stop]
            v_flat = fused[:, :, k_stop:]
        else:
            q_flat = self.q_proj(hidden_states)
            k_flat = self.k_proj(hidden_states)
            v_flat = self.v_proj(hidden_states)

        query_states = self._split_heads(q_flat, bsz, q_len, self.num_heads)
        key_states = self._split_heads(k_flat, bsz, q_len, self.num_key_value_heads)
        value_states = self._split_heads(
            v_flat, bsz, q_len, self.num_key_value_heads
        )

        if self.config.position_embedding == "rope":
            cos, sin = torch.split(pos_emb, 1, dim=1)
            query_states, key_states = self.apply_rotary_pos_emb_mtk(
                query_states, key_states, cos, sin
            )

        # (c + t)-long sequences; the returned caches are sliced from these.
        all_keys = torch.cat([past_key, key_states], dim=2)
        all_values = torch.cat([past_value, value_states], dim=2)

        score_keys, mix_values = all_keys, all_values
        if self.num_key_value_groups > 1:
            seq_len = q_len + c_len
            score_keys = self.repeat_kv(
                all_keys, bsz, seq_len, self.num_key_value_groups
            )
            mix_values = self.repeat_kv(
                all_values, bsz, seq_len, self.num_key_value_groups
            )

        scores = torch.matmul(query_states, score_keys.transpose(2, 3)) / self.attn_scale
        probs = nn.functional.softmax(scores + mask, dim=-1)

        context = torch.matmul(probs, mix_values).transpose(1, 2)
        attn_output = self.o_proj(context.reshape(bsz, q_len, self.hidden_size))

        # Drop the oldest q_len positions so the cache keeps length c.
        return attn_output, all_keys[:, :, q_len:, :], all_values[:, :, q_len:, :]


class DecoderLayer(nn.Module):
    """Pre-norm decoder block: ``h = x + attn(norm(x))`` then ``h + mlp(norm(h))``.

    ``attn_class`` and ``mlp_class`` select the sub-module implementations.
    Unless ``jit_trace`` is set, both norms run on a float32 copy of their
    input and the result is cast back to the incoming dtype.
    """

    def __init__(
        self,
        config: BaseConfig,
        return_attn=False,
        jit_trace=False,
        attn_class=Attention,
        mlp_class=MLP,
    ):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.return_attn = return_attn
        self.jit_trace = jit_trace
        self.self_attn = attn_class(config)
        self.mlp = mlp_class(config)

        # Any value other than "RMSNorm" selects nn.LayerNorm.
        norm_type = RMSNorm if config.norm == "RMSNorm" else nn.LayerNorm
        self.input_norm = norm_type(config.hidden_size, eps=config.norm_eps).float()
        self.post_attention_norm = norm_type(
            config.hidden_size, eps=config.norm_eps
        ).float()

    def _apply_norm(self, norm, x):
        """Run ``norm`` on ``x``; outside tracing, compute it in float32."""
        if self.jit_trace:
            return norm(x)
        return norm(x.to(torch.float32)).to(x.dtype)

    def forward(
        self,
        hidden_states,  # (b, t, hidden_dim)
        mask,  # (b, 1, t, c+t)
        pos_emb,  # (b, 2, t, head_dim)
        past_key,  # (b, num_kv_head, c, head_dim)
        past_value,  # (b, num_kv_head, c, head_dim)
    ):
        """Return ``(hidden_states, present_key, present_value[, attn_output])``.

        ``attn_output`` is appended only when ``return_attn`` is set.
        """
        attn_input = self._apply_norm(self.input_norm, hidden_states)
        device = attn_input.device

        attn_output, present_key, present_value = self.self_attn(
            hidden_states=attn_input.to(device),
            mask=mask.to(device),
            pos_emb=pos_emb.to(device),
            past_key=past_key.to(device),
            past_value=past_value.to(device),
        )
        after_attn = hidden_states.to(device) + attn_output

        mlp_output = self.mlp(self._apply_norm(self.post_attention_norm, after_attn))
        layer_output = after_attn + mlp_output

        if self.return_attn:
            return layer_output, present_key, present_value, attn_output
        return layer_output, present_key, present_value


class ModelChunk(BaseModelChunk):
    """A contiguous run of ``num_blocks`` decoder layers that can run on its own.

    When ``config.use_stable_embedding`` is set, the chunk with
    ``chunk_idx == 0`` applies an extra LayerNorm to its input embeddings.
    A chunk built with ``include_tail=True`` also applies the final norm and
    ``lm_head``. :meth:`forward` reads ``self.device_list``, which
    :meth:`load_weights` populates.
    """

    def __init__(
        self,
        config: BaseConfig,
        num_blocks,
        chunk_idx,
        dtype=torch.float32,
        include_tail=False,
        return_attn=False,
        jit_trace=False,
        decoder_class=DecoderLayer,
    ):
        super().__init__(
            config, num_blocks, chunk_idx, dtype, include_tail, return_attn, jit_trace
        )
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.layers = nn.ModuleList(
            [
                decoder_class(config, return_attn=return_attn, jit_trace=jit_trace)
                for _ in range(num_blocks)
            ]
        )

        if self.config.use_stable_embedding and self.chunk_idx == 0:
            # Uses nn.LayerNorm's default eps rather than config.norm_eps.
            self.embed_layer_norm = nn.LayerNorm(config.hidden_size).float()

        if self.include_tail:
            if config.norm == "RMSNorm":
                self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps).float()
            else:
                self.norm = nn.LayerNorm(
                    config.hidden_size, eps=config.norm_eps
                ).float()
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

    def forward(self, inputs_embeds, mask, pos_emb, *cache):
        """Run this chunk's decoder layers (plus optional head/tail).

        Args:
            inputs_embeds: Hidden states entering the chunk, (b, t, hidden_size).
            mask: Additive attention mask, (b, 1, t, c + t).
            pos_emb: Stacked (cos, sin), (b, 2, t, head_dim); used by the
                attention only for rotary position embeddings.
            *cache: ``2 * num_blocks`` tensors, each (b, num_kv_heads, c,
                head_dim): the past keys of every layer first, then the past
                values of every layer.

        Returns:
            Flat tuple ``(hidden_states, *keys, *values)`` with
            ``*attn_outputs`` appended when ``return_attn`` is set.
            ``hidden_states`` is the ``lm_head`` output when ``include_tail``.
        """
        if not self.jit_trace:
            assert (
                len(cache) == 2 * self.num_blocks
            ), f"split cache wrong number of input caches: {len(cache)} != 2*{self.num_blocks}"
            assert (
                cache[0].shape[0] == inputs_embeds.size()[0]
            ), f"split cache batch size mismatch: {cache[0].shape[0]} != {inputs_embeds.size()[0]}"

        inputs_embeds = inputs_embeds.to(self.device_list[0])

        if self.config.use_stable_embedding and self.chunk_idx == 0:
            if self.jit_trace:
                inputs_embeds = self.embed_layer_norm(inputs_embeds)
            else:
                inputs_embeds = self.embed_layer_norm(
                    inputs_embeds.to(torch.float32)
                ).to(self.dtype)

        hidden_states = inputs_embeds

        next_key_cache = []
        next_value_cache = []
        if self.return_attn:
            attn_outputs = []

        # decoder layers
        for idx, decoder_layer in enumerate(self.layers):
            decoder_outputs = decoder_layer(
                hidden_states.to(self.device_list[idx]),
                mask=mask.to(self.device_list[idx]),
                pos_emb=pos_emb.to(self.device_list[idx]),
                past_key=cache[idx].to(self.device_list[idx]),
                past_value=cache[self.num_blocks + idx].to(self.device_list[idx]),
            )
            hidden_states = decoder_outputs[0]
            # Collected caches are moved back to the chunk's input device.
            next_key_cache.append(decoder_outputs[1].to(inputs_embeds.device))
            next_value_cache.append(decoder_outputs[2].to(inputs_embeds.device))
            if self.return_attn:
                attn_outputs.append(decoder_outputs[3].to(inputs_embeds.device))

        if self.include_tail:
            if self.jit_trace:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states = self.norm(hidden_states.to(torch.float32)).to(
                    self.dtype
                )
            hidden_states = self.lm_head(hidden_states)

        if self.return_attn:
            return hidden_states, *next_key_cache, *next_value_cache, *attn_outputs
        return hidden_states, *next_key_cache, *next_value_cache

    @staticmethod
    def _locate_checkpoint_keys(state_dict, state_dict_start_idx):
        """Infer the checkpoint key prefix and norm module names for a layer.

        Returns:
            ``(prefix, input_norm_subkey, post_attention_norm_subkey)``.
            ``prefix`` is whatever precedes
            ``layers.{idx}.self_attn.o_proj.weight`` in the matching key. The
            subkeys are the second-to-last dotted component of the norm keys.

        Raises:
            KeyError: if any of the three cannot be found.
        """
        expected_subkey = f"layers.{state_dict_start_idx}.self_attn.o_proj.weight"
        layer_tag = f"layers.{state_dict_start_idx}"
        temp_key = None
        input_norm_subkey = None
        post_attention_norm_subkey = None
        # For each lookup the last matching key wins. NOTE: ``layer_tag`` is a
        # plain substring test (``layers.1`` also matches ``layers.12``), which
        # is harmless as long as every layer uses the same norm names.
        for key in state_dict:
            if expected_subkey in key:
                temp_key = key
            if layer_tag in key and "norm" in key:
                if "input" in key:
                    input_norm_subkey = key.split(".")[-2]
                if "post_attention" in key:
                    post_attention_norm_subkey = key.split(".")[-2]

        if temp_key is None:
            raise KeyError(
                f"Cannot find layer {state_dict_start_idx}'s o_proj weight inside state_dict. "
                f"Please ensure o_proj weight key contains: {expected_subkey}"
            )
        if input_norm_subkey is None:
            raise KeyError(
                f"Cannot find layer {state_dict_start_idx}'s input norm weight inside state_dict. "
                f"Please ensure input norm weight key contains: layers.{state_dict_start_idx}, norm, and input inside"
                " the key string."
            )
        if post_attention_norm_subkey is None:
            raise KeyError(
                f"Cannot find layer {state_dict_start_idx}'s post attention norm weight inside state_dict."
                f" Please ensure post attention norm weight key contains: layers.{state_dict_start_idx}, norm, and "
                "post_attention inside the key string."
            )
        prefix = temp_key.split(expected_subkey)[0]
        return prefix, input_norm_subkey, post_attention_norm_subkey

    def _fake_layer_weights(self, inner_layer_idx):
        """Build placeholder weights for local layer ``inner_layer_idx``.

        Weights come from ``torch.rand`` and biases are zeros; norm parameters
        are float32, everything else uses ``self.dtype``. The draw order is
        fixed, so the values depend only on the global torch RNG state.
        """
        hidden = self.config.hidden_size
        inter = self.config.intermediate_size
        kv_width = self.config.num_key_value_heads * self.head_dim
        dtype = self.dtype
        dst = f"layers.{inner_layer_idx}"

        weights = {}
        if self.config.combine_qkv:
            # Must match Attention.qkv_proj's output width (q + k + v). This is
            # smaller than 3 * hidden_size whenever
            # num_key_value_heads < num_attention_heads.
            qkv_width = (2 * kv_width) + hidden
            weights[f"{dst}.self_attn.qkv_proj.weight"] = torch.rand(
                qkv_width, hidden, dtype=dtype
            )
            weights[f"{dst}.self_attn.qkv_proj.bias"] = torch.zeros(
                qkv_width, dtype=dtype
            )
        else:
            weights[f"{dst}.self_attn.q_proj.weight"] = torch.rand(
                hidden, hidden, dtype=dtype
            )
            weights[f"{dst}.self_attn.k_proj.weight"] = torch.rand(
                kv_width, hidden, dtype=dtype
            )
            weights[f"{dst}.self_attn.v_proj.weight"] = torch.rand(
                kv_width, hidden, dtype=dtype
            )
            weights[f"{dst}.self_attn.q_proj.bias"] = torch.zeros(hidden, dtype=dtype)
            weights[f"{dst}.self_attn.k_proj.bias"] = torch.zeros(
                kv_width, dtype=dtype
            )
            weights[f"{dst}.self_attn.v_proj.bias"] = torch.zeros(
                kv_width, dtype=dtype
            )

        weights[f"{dst}.self_attn.o_proj.weight"] = torch.rand(
            hidden, hidden, dtype=dtype
        )
        weights[f"{dst}.mlp.gate_proj.weight"] = torch.rand(inter, hidden, dtype=dtype)
        weights[f"{dst}.mlp.down_proj.weight"] = torch.rand(hidden, inter, dtype=dtype)
        weights[f"{dst}.mlp.up_proj.weight"] = torch.rand(inter, hidden, dtype=dtype)
        weights[f"{dst}.input_norm.weight"] = torch.rand(hidden, dtype=torch.float32)
        weights[f"{dst}.post_attention_norm.weight"] = torch.rand(
            hidden, dtype=torch.float32
        )
        weights[f"{dst}.self_attn.o_proj.bias"] = torch.zeros(hidden, dtype=dtype)
        weights[f"{dst}.mlp.gate_proj.bias"] = torch.zeros(inter, dtype=dtype)
        weights[f"{dst}.mlp.down_proj.bias"] = torch.zeros(hidden, dtype=dtype)
        weights[f"{dst}.mlp.up_proj.bias"] = torch.zeros(inter, dtype=dtype)

        if self.config.norm == "LayerNorm":
            weights[f"{dst}.input_norm.bias"] = torch.zeros(
                hidden, dtype=torch.float32
            )
            weights[f"{dst}.post_attention_norm.bias"] = torch.zeros(
                hidden, dtype=torch.float32
            )
        return weights

    def _checkpoint_layer_weights(
        self,
        state_dict,
        prefix,
        inner_layer_idx,
        outer_layer_idx,
        input_norm_subkey,
        post_attention_norm_subkey,
    ):
        """Pop global layer ``outer_layer_idx`` from ``state_dict``.

        The popped tensors are re-keyed for local layer ``inner_layer_idx``.
        Weights are required (a missing one raises ``KeyError``). Biases
        default to zeros of ``self.dtype``. Norm parameters are converted to
        float32.
        """
        hidden = self.config.hidden_size
        inter = self.config.intermediate_size
        kv_width = self.config.num_key_value_heads * self.head_dim
        src = f"{prefix}layers.{outer_layer_idx}"
        dst = f"layers.{inner_layer_idx}"

        def required(name):
            return state_dict.pop(f"{src}.{name}")

        def optional(name, size):
            return state_dict.pop(f"{src}.{name}", torch.zeros(size, dtype=self.dtype))

        weights = {}
        if self.config.combine_qkv:
            weights[f"{dst}.self_attn.qkv_proj.weight"] = required(
                "self_attn.qkv_proj.weight"
            )
            weights[f"{dst}.self_attn.qkv_proj.bias"] = optional(
                "self_attn.qkv_proj.bias", (2 * kv_width) + hidden
            )
        else:
            weights[f"{dst}.self_attn.q_proj.weight"] = required(
                "self_attn.q_proj.weight"
            )
            weights[f"{dst}.self_attn.k_proj.weight"] = required(
                "self_attn.k_proj.weight"
            )
            weights[f"{dst}.self_attn.v_proj.weight"] = required(
                "self_attn.v_proj.weight"
            )
            weights[f"{dst}.self_attn.q_proj.bias"] = optional(
                "self_attn.q_proj.bias", hidden
            )
            weights[f"{dst}.self_attn.k_proj.bias"] = optional(
                "self_attn.k_proj.bias", kv_width
            )
            weights[f"{dst}.self_attn.v_proj.bias"] = optional(
                "self_attn.v_proj.bias", kv_width
            )

        weights[f"{dst}.self_attn.o_proj.weight"] = required("self_attn.o_proj.weight")
        weights[f"{dst}.mlp.gate_proj.weight"] = required("mlp.gate_proj.weight")
        weights[f"{dst}.mlp.down_proj.weight"] = required("mlp.down_proj.weight")
        weights[f"{dst}.mlp.up_proj.weight"] = required("mlp.up_proj.weight")
        weights[f"{dst}.input_norm.weight"] = required(
            f"{input_norm_subkey}.weight"
        ).to(torch.float32)
        weights[f"{dst}.post_attention_norm.weight"] = required(
            f"{post_attention_norm_subkey}.weight"
        ).to(torch.float32)
        weights[f"{dst}.self_attn.o_proj.bias"] = optional(
            "self_attn.o_proj.bias", hidden
        )
        weights[f"{dst}.mlp.gate_proj.bias"] = optional("mlp.gate_proj.bias", inter)
        weights[f"{dst}.mlp.down_proj.bias"] = optional("mlp.down_proj.bias", hidden)
        weights[f"{dst}.mlp.up_proj.bias"] = optional("mlp.up_proj.bias", inter)

        if self.config.norm == "LayerNorm":
            weights[f"{dst}.input_norm.bias"] = optional(
                f"{input_norm_subkey}.bias", hidden
            ).to(torch.float32)
            weights[f"{dst}.post_attention_norm.bias"] = optional(
                f"{post_attention_norm_subkey}.bias", hidden
            ).to(torch.float32)
        return weights

    def _layer_device(self, outer_layer_idx):
        """Return the device for global layer ``outer_layer_idx``.

        Returns ``"cpu"`` when no CUDA device is visible or when tracing.
        Otherwise layers are assigned in contiguous groups of
        ``ceil(num_hidden_layers / device_count)`` per GPU.
        """
        num_gpus = torch.cuda.device_count()
        if num_gpus == 0 or self.jit_trace:
            return "cpu"
        num_layers = self.config.num_hidden_layers
        layers_per_gpu = num_layers // num_gpus + (num_layers % num_gpus != 0)
        return f"cuda:{outer_layer_idx // layers_per_gpu}"

    def load_weights(self, state_dict, state_dict_start_idx):
        """Install weights into this chunk and place its modules on devices.

        Args:
            state_dict: Full-model checkpoint mapping, or ``None`` to fill the
                chunk with random/zero placeholder weights. Entries used by this
                chunk are popped from it.
            state_dict_start_idx: Global index of this chunk's first decoder
                layer in the checkpoint.

        Returns:
            ``self``, switched to eval mode.

        Raises:
            KeyError: if the checkpoint lacks the first layer's o_proj weight,
                its norm weights, or any other required weight.
            RuntimeError: if the assembled keys differ from this module's keys.
        """
        fake_weights = state_dict is None
        if fake_weights:
            prefix = input_norm_subkey = post_attention_norm_subkey = None
        else:
            (
                prefix,
                input_norm_subkey,
                post_attention_norm_subkey,
            ) = self._locate_checkpoint_keys(state_dict, state_dict_start_idx)

        self.device_list = []
        temp_state_dict = {}
        if self.config.use_stable_embedding and self.chunk_idx == 0:
            if fake_weights:
                temp_state_dict["embed_layer_norm.weight"] = torch.rand(
                    self.config.hidden_size, dtype=torch.float32
                )
                temp_state_dict["embed_layer_norm.bias"] = torch.zeros(
                    self.config.hidden_size, dtype=torch.float32
                )
            else:
                temp_state_dict["embed_layer_norm.weight"] = state_dict.pop(
                    f"{prefix}embed_layer_norm.weight"
                ).to(torch.float32)
                temp_state_dict["embed_layer_norm.bias"] = state_dict.pop(
                    f"{prefix}embed_layer_norm.bias",
                    torch.zeros(self.config.hidden_size, dtype=self.dtype),
                ).to(torch.float32)

        for inner_layer_idx in range(self.num_blocks):
            outer_layer_idx = state_dict_start_idx + inner_layer_idx
            if fake_weights:
                temp_state_dict.update(self._fake_layer_weights(inner_layer_idx))
            else:
                temp_state_dict.update(
                    self._checkpoint_layer_weights(
                        state_dict,
                        prefix,
                        inner_layer_idx,
                        outer_layer_idx,
                        input_norm_subkey,
                        post_attention_norm_subkey,
                    )
                )
            self.device_list.append(self._layer_device(outer_layer_idx))

        if self.include_tail:
            if fake_weights:
                temp_state_dict["norm.weight"] = torch.rand(
                    self.config.hidden_size, dtype=torch.float32
                )
                temp_state_dict["lm_head.weight"] = torch.rand(
                    self.config.vocab_size,
                    self.config.hidden_size,
                    dtype=self.dtype,
                )
                temp_state_dict["lm_head.bias"] = torch.zeros(
                    self.config.vocab_size, dtype=self.dtype
                )
                if self.config.norm == "LayerNorm":
                    temp_state_dict["norm.bias"] = torch.zeros(
                        self.config.hidden_size, dtype=torch.float32
                    )
            else:
                if self.config.tie_word_embeddings:
                    # Tied head: reuse the input embedding matrix.
                    lm_head_weight_key = f"{prefix}embed_tokens.weight"
                    lm_head_bias_key = f"{prefix}embed_tokens.bias"
                else:
                    lm_head_weight_key = "lm_head.weight"
                    lm_head_bias_key = "lm_head.bias"
                temp_state_dict["lm_head.weight"] = state_dict.pop(lm_head_weight_key)
                temp_state_dict["norm.weight"] = state_dict.pop(
                    f"{prefix}norm.weight"
                ).to(torch.float32)
                temp_state_dict["lm_head.bias"] = state_dict.pop(
                    lm_head_bias_key,
                    torch.zeros(self.config.vocab_size, dtype=self.dtype),
                )
                if self.config.norm == "LayerNorm":
                    temp_state_dict["norm.bias"] = state_dict.pop(
                        f"{prefix}norm.bias",
                        torch.zeros(self.config.hidden_size, dtype=self.dtype),
                    ).to(torch.float32)

        print(f"Loading weights for chunk {self.chunk_idx}")
        model_keys = self.state_dict().keys()
        if temp_state_dict.keys() != model_keys:
            temp_state_dict_only_keys = [
                x for x in temp_state_dict if x not in model_keys
            ]
            model_only_keys = [x for x in model_keys if x not in temp_state_dict]
            raise RuntimeError(
                f"model state dict keys don't match with state_dict to load into model.\nModel only keys:{model_only_keys}\nstate_dict only keys:{temp_state_dict_only_keys}"
            )
        self.load_state_dict(temp_state_dict)
        for i in range(self.num_blocks):
            self.layers[i].to(self.device_list[i])
        if self.config.use_stable_embedding and self.chunk_idx == 0:
            self.embed_layer_norm.to(self.device_list[0])
        if self.include_tail:
            self.norm.to(self.device_list[-1])
            self.lm_head.to(self.device_list[-1])
        self.eval()

        return self

    def get_example_inputs(
        self, num_token: int = 128, cache_size: int = 512, get_dym_shape=False
    ):
        """Build random CPU float32 inputs matching :meth:`forward`'s signature.

        Args:
            num_token: Number of new tokens ``t``.
            cache_size: KV-cache length ``c``.
            get_dym_shape: Also return a ``dynamic_shapes`` spec for
                ``torch.export``. In it ``t`` is dynamic (max ``num_token``),
                the mask's last dim is ``t + cache_size``, and caches are static.

        Returns:
            The example-input tuple, or ``(example_inputs, dynamic_shapes)``
            when ``get_dym_shape`` is true.
        """
        head_dim = int(self.config.hidden_size / self.config.num_attention_heads)
        example_inputs = (
            torch.randn(
                1, num_token, self.config.hidden_size, device="cpu", dtype=torch.float32
            ),
            torch.randn(
                1,
                1,
                num_token,
                cache_size + num_token,
                device="cpu",
                dtype=torch.float32,
            ),
            torch.randn(1, 2, num_token, head_dim, device="cpu", dtype=torch.float32),
            *[
                torch.randn(
                    1,
                    self.config.num_key_value_heads,
                    cache_size,
                    head_dim,
                    device="cpu",
                    dtype=torch.float32,
                )
                for _ in range(2 * self.num_blocks)
            ],
        )
        # Specify dims that would be dynamic during calibration
        # Note: Assume cache size fixed shape as torch dynamic shape cannot handle dim 3 being
        # combination of 2 dynamic dims
        if get_dym_shape:
            nt = Dim("num_token", max=num_token)
            cache_dims = tuple(({} for _ in range(2 * self.num_blocks)))
            dynamic_shapes = (
                {0: Dim.STATIC, 1: nt, 2: Dim.STATIC},
                {0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: nt + cache_size},
                {0: Dim.STATIC, 1: Dim.STATIC, 2: nt, 3: Dim.STATIC},
                cache_dims,
            )
            return example_inputs, dynamic_shapes

        return example_inputs
