# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional


def get_xnnpack_partitioner(dynamic_quant_only_partitioner: bool = True):
    """
    Returns the XNNPACK partitioner.

    @arg dynamic_quant_only_partitioner:
        Enabled by default to preserve backwards compatibility.
        If True, only dynamically quantized linear layers are partitioned;
        otherwise, everything that can be partitioned is partitioned greedily.
    """
    from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
        XnnpackDynamicallyQuantizedPartitioner,
        XnnpackPartitioner,
    )

    if dynamic_quant_only_partitioner:
        # We use the dynamically quantized partitioner here because:
        # 1. It is needed both for the pt2e_quantize options and for
        #    "qmode 8da4w", which also dynamically quantizes linear layers.
        # 2. The generic XNNPACK partitioner seems to segfault on non-dqlinear ops.
        return XnnpackDynamicallyQuantizedPartitioner()
    return XnnpackPartitioner()
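
# Example usage (a sketch, not part of this module): `exported_program` is
# assumed to be a torch.export.ExportedProgram for the model being lowered.
#
#     from executorch.exir import to_edge
#
#     edge = to_edge(exported_program)
#     edge = edge.to_backend(get_xnnpack_partitioner())
#     et_program = edge.to_executorch()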


def get_vulkan_partitioner(
    dtype_override: Optional[str] = None, enable_dynamic_shape: bool = False
):
    assert (
        dtype_override == "fp32" or dtype_override is None
    ), "The Vulkan backend currently supports only the fp32 dtype"
    from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
        VulkanPartitioner,
    )

    return VulkanPartitioner({"require_dynamic_shapes": enable_dynamic_shape})
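
# Example (sketch): request fp32 with dynamic shapes enabled.
#
#     partitioner = get_vulkan_partitioner(
#         dtype_override="fp32", enable_dynamic_shape=True
#     )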


def get_mps_partitioner(use_kv_cache: bool = False):
    from executorch.exir.backend.backend_details import CompileSpec

    assert (
        use_kv_cache is True
    ), "The MPS backend currently supports only static shapes; use_kv_cache=True is required"
    try:
        # pyre-ignore Undefined import [21]: Could not find a module corresponding to import `executorch.backends.apple.mps.partition.mps_partitioner`.
        from executorch.backends.apple.mps.partition.mps_partitioner import (
            MPSPartitioner,
        )
    except ImportError:
        raise ImportError(
            "Please install the MPS backend follwing https://pytorch.org/executorch/main/build-run-mps.html"
        )

    compile_specs = [CompileSpec("use_fp16", bytes([True]))]
    return MPSPartitioner(compile_specs)  # pyre-fixme[16]
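
# Example (sketch): the MPS backend needs a static-shape model, so
# use_kv_cache=True is the only supported configuration.
#
#     partitioner = get_mps_partitioner(use_kv_cache=True)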


def get_coreml_partitioner(
    ios: int = 15,
    embedding_quantize: Optional[str] = None,
    pt2e_quantize: Optional[str] = None,
    coreml_quantize: Optional[str] = None,
):
    try:
        import coremltools as ct
        from executorch.backends.apple.coreml.compiler import (  # pyre-ignore
            CoreMLBackend,
        )
        from executorch.backends.apple.coreml.partition import (  # pyre-ignore
            CoreMLPartitioner,
        )
    except ImportError:
        raise ImportError(
            "Please install the CoreML backend follwing https://pytorch.org/executorch/main/build-run-coreml.html"
        )

    def _validate_ios_version() -> None:
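        """
        Check that the requested quantization features are available on the
        target iOS version:
          - any quantization (pt2e or Core ML): iOS 16+
          - 8-bit activation quantization ("8a"): iOS 17+
          - 4-bit weights ("4w") and per-block quantization: iOS 18+
        """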
        assert ios in (15, 16, 17, 18)

        if embedding_quantize is not None and ios < 18:
            raise ValueError(
                "In Core ML, per-block quantization is introduced in iOS 18"
            )

        use_quantization = pt2e_quantize is not None or coreml_quantize is not None
        if use_quantization and ios < 16:
            raise ValueError("In Core ML, quantization requires iOS 16 or later")

        use_8a = (pt2e_quantize is not None and "8a" in pt2e_quantize) or (
            coreml_quantize is not None and "8a" in coreml_quantize
        )
        if use_8a and ios < 17:
            raise ValueError(
                "In Core ML, 8-bit activation quantization is introduced in iOS 17"
            )

        use_4w = (pt2e_quantize is not None and "4w" in pt2e_quantize) or (
            coreml_quantize is not None and "4w" in coreml_quantize
        )
        if use_4w and ios < 18:
            raise ValueError(
                "In Core ML, 4-bit weight compression is introduced in iOS 18"
            )

    _validate_ios_version()

    minimum_deployment_target = {
        15: ct.target.iOS15,
        16: ct.target.iOS16,
        17: ct.target.iOS17,
        18: ct.target.iOS18,
    }[ios]
    op_linear_quantizer_config = None
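    # "b4w" maps onto coremltools' linear quantizer settings: symmetric int4
    # weights quantized per block of 32 elements; weight tensors smaller than
    # `weight_threshold` (512 elements) are left unquantized.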
    if coreml_quantize == "b4w":
        op_linear_quantizer_config = {
            "mode": "linear_symmetric",
            "dtype": "int4",
            "granularity": "per_block",
            "block_size": 32,
            "weight_threshold": 512,
        }
    compile_specs = CoreMLBackend.generate_compile_specs(  # pyre-fixme[16]
        minimum_deployment_target=minimum_deployment_target,
        compute_precision=ct.precision.FLOAT16,
        # `ComputeUnit.ALL` can increase model load time, so default to `ComputeUnit.CPU_AND_GPU`.
        compute_unit=ct.ComputeUnit.CPU_AND_GPU,
        model_type=CoreMLBackend.MODEL_TYPE.MODEL,  # pyre-fixme[16]
        op_linear_quantizer_config=op_linear_quantizer_config,
    )

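    # Core ML gains stateful model support in iOS 18; only then can the
    # delegate take ownership of mutable buffers such as the KV cache.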
    take_over_mutable_buffer = minimum_deployment_target >= ct.target.iOS18

    return CoreMLPartitioner(  # pyre-fixme[16]
        compile_specs=compile_specs,
        take_over_mutable_buffer=take_over_mutable_buffer,
    )
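
# Example (sketch): target iOS 18 with 4-bit per-block weight quantization.
#
#     partitioner = get_coreml_partitioner(ios=18, coreml_quantize="b4w")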


def get_qnn_partitioner(
    use_kv_cache: bool = False,
    pt2e_quantize: Optional[str] = None,
    num_sharding: int = 0,
    soc_model: str = "SM8650",  # default to SM8650
):
    assert (
        use_kv_cache is True
    ), "The Qualcomm backend currently supports only static shapes; use_kv_cache=True is required"
    try:
        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.partition.qnn_partitioner`
        from executorch.backends.qualcomm.partition.qnn_partitioner import (
            QnnPartitioner,
        )

        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.serialization.qc_schema`
        from executorch.backends.qualcomm.serialization.qc_schema import QcomChipset

        # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
        from executorch.backends.qualcomm.utils.utils import (
            generate_htp_compiler_spec,
            generate_qnn_executorch_compiler_spec,
        )
    except ImportError:
        raise ImportError(
            "Please install the Qualcomm backend following https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html"
        )

    # Run in fp16 unless the model is quantized via pt2e.
    use_fp16 = pt2e_quantize is None
    skip_node_op_set = {"llama.fallback.default"}

    return QnnPartitioner(  # pyre-fixme[16]
        generate_qnn_executorch_compiler_spec(  # pyre-fixme[16]
            soc_model=getattr(QcomChipset, soc_model),  # pyre-fixme[16]
            # pyre-fixme[16]
            backend_options=generate_htp_compiler_spec(
                use_fp16=use_fp16,
                use_multi_contexts=num_sharding > 0,
            ),
            debug=False,
            saver=False,
        ),
        skip_node_id_set=set(),
        skip_node_op_set=skip_node_op_set,
    )
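
# Example (sketch): partition for a Snapdragon 8 Gen 2 (SM8550) target with
# two-way sharding; both values are illustrative.
#
#     partitioner = get_qnn_partitioner(
#         use_kv_cache=True, num_sharding=2, soc_model="SM8550"
#     )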
