# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
    AnnotateChannelsLastDimOrder,
)
from executorch.backends.arm._passes.cast_int64_pass import CastInt64ToInt32Pass
from executorch.backends.arm._passes.conv1d_unsqueeze_pass import Conv1dUnsqueezePass
from executorch.backends.arm._passes.convert_expand_copy_to_repeat import (
    ConvertExpandCopyToRepeatPass,
)
from executorch.backends.arm._passes.convert_split_to_slice import (
    ConvertSplitToSlicePass,
)
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
from executorch.backends.arm._passes.decompose_layernorm_pass import (
    DecomposeLayerNormPass,
)
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
    DecomposeSoftmaxesPass,
)
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
    InsertSqueezeAfterSumPass,
)
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
    ConvertMeanDimToAveragePool,
)
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
from executorch.backends.arm._passes.scalars_to_attribute_pass import (
    ScalarsToAttributePass,
)
from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
    UnsqueezeScalarPlaceholdersPass,
)
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.pass_manager import PassManager


class ArmPassManager(PassManager):
    """Pass manager that assembles and runs the Arm backend pass pipelines.

    Each ``transform_*`` method registers its passes on this instance via
    ``add_pass`` and then runs the manager on a graph module.

    NOTE(review): the passes are registered on the instance on every call;
    presumably each instance is meant to run a single pipeline -- confirm
    against callers.
    """

    def _transform(self, graph_module: torch.fx.GraphModule):
        """Run the registered passes on ``graph_module``.

        Returns the ``graph_module`` attribute of the result produced by
        calling this manager.
        """
        pass_result = self(graph_module)
        return pass_result.graph_module

    def transform_to_backend_pipeline(
        self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
    ):
        """Register the backend lowering passes and apply them.

        Args:
            exported_program: Program whose ``graph_module`` is transformed.
                It is also handed to the passes that take it as a
                constructor argument.
            compile_spec: Compile specs to scan. For every spec whose key is
                ``"permute_memory_format"`` and whose decoded value is
                ``"nhwc"``, ``AnnotateChannelsLastDimOrder`` is appended to
                the pipeline.

        Returns:
            The graph module obtained by running the registered passes on
            ``exported_program.graph_module``.
        """
        # Registration order is significant; keep it exactly as listed.
        backend_passes = (
            CastInt64ToInt32Pass(exported_program),
            RemoveGetItemPass(),
            UnsqueezeScalarPlaceholdersPass(exported_program),
            SizeAdjustConv2DPass(),
            RemoveClonePass(),
            ConvertExpandCopyToRepeatPass(),
            DecomposeLayerNormPass(),
            DecomposeVarPass(),
            ConvertMeanDimToAveragePool(),
            DecomposeMeanDimPass(),
            MatchArgRanksPass(exported_program),
            DecomposeDivPass(),
            InsertSqueezeAfterSumPass(),
            ConvertSplitToSlicePass(),
            Conv1dUnsqueezePass(exported_program),
            DecomposeSoftmaxesPass(),
            DecomposeLinearPass(),
        )
        for backend_pass in backend_passes:
            self.add_pass(backend_pass)

        for entry in compile_spec:
            # Only the memory-format spec influences the pipeline.
            if entry.key != "permute_memory_format":
                continue
            if entry.value.decode() == "nhwc":
                self.add_pass(AnnotateChannelsLastDimOrder())

        return self._transform(exported_program.graph_module)

    def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
        """Register the annotation-stage passes and apply them.

        Args:
            graph_module: Graph module to run the passes on.

        Returns:
            The graph module obtained by running the registered passes on
            ``graph_module``.
        """
        # Registration order is significant; keep it exactly as listed.
        annotation_passes = (
            ScalarsToAttributePass(),
            DecomposeLayerNormPass(),
            DecomposeVarPass(),
            DecomposeMeanDimPass(),
            DecomposeDivPass(),
            DecomposeSoftmaxesPass(),
        )
        for annotation_pass in annotation_passes:
            self.add_pass(annotation_pass)
        return self._transform(graph_module)
