# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.backends.xnnpack.test.tester import Tester
from executorch.examples.models.llama.model import Llama2Model


class TestLlama2ETExample(unittest.TestCase):
    """Run the example Llama2 model through the XNNPACK ``Tester`` pipeline.

    Each test obtains the eager model from ``Llama2Model`` and drives it
    through the ``Tester`` stages ``export``, ``to_edge_transform_and_lower``,
    ``to_executorch`` and ``serialize`` before calling
    ``run_method_and_compare_outputs`` on the example inputs.
    """

    # Floating-point types `_test` accepts (``torch.float`` is fp32).
    _SUPPORTED_DTYPES = (torch.float32, torch.float16)

    def test_f32(self) -> None:
        """Exercise the pipeline with the default fp32 dtype."""
        self._test()

    def test_f16(self) -> None:
        """Exercise the pipeline with the model and fp32 inputs cast to fp16."""
        self._test(torch.float16)

    # TODO - dynamic shape

    def _test(self, dtype: torch.dtype = torch.float) -> None:
        """Export, lower, serialize and run the Llama2 example in ``dtype``.

        Args:
            dtype: Type the eager model is cast to. Example inputs whose
                dtype is fp32 are cast to it as well; all other inputs are
                passed through unchanged.

        Raises:
            AssertionError: If ``dtype`` is neither fp32 nor fp16 (raised via
                ``self.failureException``, which is ``AssertionError`` for a
                plain ``unittest.TestCase``).
        """
        # A TestCase assertion is used instead of a bare ``assert`` so that the
        # guard is still enforced when Python runs with ``-O``.
        self.assertIn(
            dtype,
            self._SUPPORTED_DTYPES,
            f"Only fp32 and fp16 are supported, but got dtype: {dtype}",
        )

        llama2 = Llama2Model()
        model = llama2.get_eager_model().to(dtype)

        # Only convert fp32 inputs to dtype
        example_inputs = tuple(
            tensor.to(dtype) if tensor.dtype == torch.float32 else tensor
            for tensor in llama2.get_example_inputs()
        )

        (
            Tester(model, example_inputs)
            .export()
            .to_edge_transform_and_lower()
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs(atol=5e-2, inputs=example_inputs)
        )
