# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from typing import Optional

import torch
from transformers import PretrainedConfig, StaticCache


class ETStaticCache(StaticCache):
    """
    A customized static cache implementation that overrides a few methods so it can be exported to ExecuTorch.

    This can be removed once transformers properly supports a static cache for Phi3.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        max_batch_size: int,
        max_cache_len: int,
        device,
        dtype=torch.float32,
    ) -> None:
        # Pure pass-through to ``StaticCache.__init__``. Every argument is forwarded
        # by keyword, so the subclass constructor signature stays explicit.
        super().__init__(
            config=config,
            max_batch_size=max_batch_size,
            max_cache_len=max_cache_len,
            device=device,
            dtype=dtype,
        )

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Return the number of occupied positions in the key cache of one layer.

        The method inspects only ``key_cache[layer_idx][0, 0]``, i.e. index 0 of the
        two leading dimensions. The layout is presumably (batch, num_heads, seq_len,
        head_dim), as in ``StaticCache``; confirm this against the installed
        transformers version.

        A position counts as occupied when any element along the last dimension is
        non-zero. A filled position whose values are all exactly zero is therefore
        not counted. The count is converted to a Python number with ``.item()``.

        Args:
            layer_idx: Index into ``key_cache``. ``None`` is treated as ``0``.

        Returns:
            The number of occupied positions.
        """
        # ``Optional[int]`` advertises None as valid, but indexing the cache with
        # None would raise, so fall back to the first layer.
        if layer_idx is None:
            layer_idx = 0
        # pyre-fixme[16]: `ETStaticCache` has no attribute `key_cache`.
        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum().item()

    def get_usable_length(
        self, new_seq_length: int, layer_idx: Optional[int] = 0
    ) -> int:
        """
        Return the cached length for ``layer_idx``.

        This simply delegates to ``get_seq_length``. ``new_seq_length`` is accepted
        for interface compatibility but is not used.

        Args:
            new_seq_length: Unused.
            layer_idx: Index into the key cache. ``None`` is treated as ``0``.

        Returns:
            The value of ``get_seq_length(layer_idx)``.
        """
        return self.get_seq_length(layer_idx)
