# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import shutil
import tempfile
import unittest

import torch
from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.arm_tester import ArmTester

# Module-level logger for this test file; its level is set to INFO so records
# at INFO and above are handled by it.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class Linear(torch.nn.Module):
    """Single ``torch.nn.Linear`` layer bundled with a fixed example input.

    The example input is one ``(5, 10, 25, in_features)`` tensor sampled with
    ``torch.randn`` at construction time; ``get_inputs`` hands back that same
    tuple on every call.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int = 3,
        bias: bool = True,
    ):
        super().__init__()
        # Under a fixed seed the draw order matters: the example input is
        # sampled before the layer initialises its own parameters.
        example = torch.randn(5, 10, 25, in_features)
        self.inputs = (example,)
        self.fc = torch.nn.Linear(in_features, out_features, bias=bias)

    def get_inputs(self):
        """Return the example-input tuple created in ``__init__``."""
        return self.inputs

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        projected = self.fc(x)
        return projected


class TestDumpPartitionedArtifact(unittest.TestCase):
    """Tests dumping the partitioned artifact in ArmTester, both to file and to stdout."""

    def _tosa_MI_pipeline(self, module: torch.nn.Module, dump_file=None):
        """Lower ``module`` for TOSA-0.80.0+MI and dump the partitioned artifact.

        The artifact is dumped twice after partitioning: once into
        ``dump_file`` and once with no destination argument.
        """
        (
            ArmTester(
                module,
                example_inputs=module.get_inputs(),
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .to_edge()
            .partition()
            .dump_artifact(dump_file)
            .dump_artifact()
        )

    def _tosa_BI_pipeline(self, module: torch.nn.Module, dump_file=None):
        """Quantize and lower ``module`` for TOSA-0.80.0+BI, then dump the artifact.

        The artifact is dumped twice after lowering: once into ``dump_file``
        and once with no destination argument.
        """
        (
            ArmTester(
                module,
                example_inputs=module.get_inputs(),
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .to_edge_transform_and_lower()
            .dump_artifact(dump_file)
            .dump_artifact()
        )

    def _is_tosa_marker_in_file(self, tmp_file):
        """Return True if any line of ``tmp_file`` contains ``'name': 'main'``.

        That text is used as the marker that the file holds a TOSA dump.
        """
        # The context manager closes the handle deterministically instead of
        # leaving an open file object for the garbage collector.
        with open(tmp_file) as dump:
            return any("'name': 'main'" in line for line in dump)

    def _assert_artifact_dumped(self, pipeline, model, file_name):
        """Run ``pipeline`` dumping to ``file_name`` in a temp dir and verify it.

        The temporary directory (and the dump inside it) is removed afterwards,
        whether or not the assertions pass.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_file = os.path.join(tmp_dir, file_name)
            pipeline(model, dump_file=tmp_file)
            self.assertTrue(
                os.path.exists(tmp_file), f"File {tmp_file} was not created"
            )
            self.assertTrue(
                self._is_tosa_marker_in_file(tmp_file),
                "File does not contain TOSA dump!",
            )

    def test_MI_artifact(self):
        self._assert_artifact_dumped(
            self._tosa_MI_pipeline, Linear(20, 30), "tosa_dump_MI.txt"
        )

    def test_BI_artifact(self):
        self._assert_artifact_dumped(
            self._tosa_BI_pipeline, Linear(20, 30), "tosa_dump_BI.txt"
        )


class TestNumericalDiffPrints(unittest.TestCase):
    """Tests triggering the exception printout from the ArmTester's run and compare function."""

    def test_numerical_diff_prints(self):
        model = Linear(20, 30)
        tester = (
            ArmTester(
                model,
                example_inputs=model.get_inputs(),
                compile_spec=common.get_tosa_compile_spec(
                    "TOSA-0.80.0+MI", permute_memory_to_nhwc=True
                ),
            )
            .export()
            .to_edge_transform_and_lower()
            .to_executorch()
        )
        # Zero tolerance on every metric is meant to force a numerical diff,
        # which the comparison reports by raising AssertionError. Any other
        # exception propagates and fails the test. If nothing is raised,
        # assertRaises fails the test with a descriptive message.
        with self.assertRaises(AssertionError):
            tester.run_method_and_compare_outputs(atol=0, rtol=0, qtol=0)


def _dump_ops_and_dtypes_pipeline(**dump_kwargs):
    """Run the TOSA-0.80.0+BI pipeline on ``Linear(20, 30)`` with distribution dumps.

    The dtype and operator distributions are dumped after each of the
    ``quantize``, ``export`` and ``to_edge_transform_and_lower`` stages.
    ``dump_kwargs`` is forwarded unchanged to every dump call, so calling this
    with no arguments uses ArmTester's own defaults.
    """
    model = Linear(20, 30)
    (
        ArmTester(
            model,
            example_inputs=model.get_inputs(),
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        )
        .quantize()
        .dump_dtype_distribution(**dump_kwargs)
        .dump_operator_distribution(**dump_kwargs)
        .export()
        .dump_dtype_distribution(**dump_kwargs)
        .dump_operator_distribution(**dump_kwargs)
        .to_edge_transform_and_lower()
        .dump_dtype_distribution(**dump_kwargs)
        .dump_operator_distribution(**dump_kwargs)
    )


def test_dump_ops_and_dtypes():
    """Dumping distributions with default arguments raises no exceptions."""
    _dump_ops_and_dtypes_pipeline()


def test_dump_ops_and_dtypes_parseable():
    """Dumping distributions with ``print_table=False`` raises no exceptions."""
    _dump_ops_and_dtypes_pipeline(print_table=False)


class TestCollateTosaTests(unittest.TestCase):
    """Tests the collation of TOSA tests through setting the environment variable TOSA_TESTCASES_BASE_PATH."""

    def test_collate_tosa_BI_tests(self):
        from unittest import mock

        base_path = "test_collate_tosa_tests"
        # Start from an empty output directory so leftovers from an earlier run
        # cannot make the existence checks below pass spuriously. Register the
        # removal as a cleanup so it also happens when an assertion fails.
        shutil.rmtree(base_path, ignore_errors=True)
        self.addCleanup(shutil.rmtree, base_path, ignore_errors=True)

        # patch.dict restores os.environ on exit, including any value the
        # variable had before. This way it cannot leak into later tests, even
        # if lowering raises. The pipeline stays inline in this method because
        # the expected output path below contains the test class and method
        # names.
        with mock.patch.dict(os.environ, {"TOSA_TESTCASES_BASE_PATH": base_path}):
            model = Linear(20, 30)
            (
                ArmTester(
                    model,
                    example_inputs=model.get_inputs(),
                    compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
                )
                .quantize()
                .export()
                .to_edge_transform_and_lower()
                .to_executorch()
            )

        # Test that the output directory is created and contains the expected files.
        test_dir = os.path.join(
            base_path, "tosa-bi", "TestCollateTosaTests", "test_collate_tosa_BI_tests"
        )
        self.assertTrue(os.path.exists(test_dir), f"{test_dir} was not created")
        for file_name in ("output_tag5.tosa", "desc_tag5.json"):
            path = os.path.join(test_dir, file_name)
            self.assertTrue(os.path.exists(path), f"{path} was not created")


def test_dump_tosa_ops(caplog):
    """Dumping the operator distribution of a lowered BI graph logs "TOSA operators:"."""
    caplog.set_level(logging.INFO)
    linear = Linear(20, 30)
    tester = ArmTester(
        linear,
        example_inputs=linear.get_inputs(),
        compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
    )
    lowered = tester.quantize().export().to_edge_transform_and_lower()
    lowered.dump_operator_distribution()
    assert "TOSA operators:" in caplog.text


def test_fail_dump_tosa_ops(caplog):
    """With a U55 compile spec, the operator-distribution dump logs that it cannot run."""
    caplog.set_level(logging.INFO)

    class Add(torch.nn.Module):
        def forward(self, x):
            return x + x

    add_module = Add()
    u55_spec = common.get_u55_compile_spec()
    tester = ArmTester(
        add_module, example_inputs=(torch.ones(5),), compile_spec=u55_spec
    )
    lowered = tester.quantize().export().to_edge_transform_and_lower()
    lowered.dump_operator_distribution()
    expected_message = "Can not get operator distribution for Vela command stream."
    assert expected_message in caplog.text
