# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse
import copy
import logging
import time

import torch
from executorch.exir import EdgeCompileConfig
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.extension.export_util.utils import export_to_edge, save_pte_program
from torch.ao.ns.fx.utils import compute_sqnr
from torch.ao.quantization import (  # @manual
    default_per_channel_symmetric_qnnpack_qconfig,
    QConfigMapping,
)
from torch.ao.quantization.backend_config import get_executorch_backend_config
from torch.ao.quantization.quantize_fx import (
    _convert_to_reference_decomposed_fx,
    prepare_fx,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    get_symmetric_quantization_config,
    XNNPACKQuantizer,
)

from ...models import MODEL_NAME_TO_MODEL
from ...models.model_factory import EagerModelFactory

from .. import MODEL_NAME_TO_OPTIONS
from .utils import quantize


# Log record layout used by this script: level, timestamp, source file:line, message.
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
# Configure logging at import time so the INFO-level progress/timing messages
# emitted below are actually shown.
logging.basicConfig(level=logging.INFO, format=FORMAT)


def verify_xnnpack_quantizer_matching_fx_quant_model(model_name, model, example_inputs):
    """Sanity-check the PT2E XNNPACKQuantizer flow against FX graph mode quantization.

    The model is quantized twice -- once with ``prepare_pt2e``/``convert_pt2e``
    driven by ``XNNPACKQuantizer`` and once with ``prepare_fx``/
    ``_convert_to_reference_decomposed_fx`` -- and the outputs both flows produce
    for ``example_inputs`` are compared after the prepare step and after the
    convert step.

    Args:
        model_name: Name of the model; selects the skip list and the ``dl3``
            output-unpacking special case.
        model: Eager model to verify. It is switched to eval mode in place; the
            FX flow runs on a deep copy of it.
        example_inputs: Positional inputs, unpacked as ``m(*example_inputs)`` for
            export, calibration and comparison.

    Raises:
        AssertionError: If the prepare-stage SQNR is not above 100, if the largest
            absolute element-wise difference between the two quantized outputs
            is not below 1e-1, or if the quantized SQNR is not above 30. The
            checks are raised explicitly (not via ``assert``) so they still run
            under ``python -O``.
    """
    # Models for which the comparison is skipped:
    #  - edsr has control flows that are not traceable in symbolic_trace
    #  - mobilebert is not symbolically traceable with torch.fx.symbolic_trace
    #  - ic3 (inception_v3): mul op with Scalar input is quantized differently in
    #    fx, and we don't want to replicate that behavior in XNNPACKQuantizer
    if model_name in ("edsr", "mobilebert", "ic3"):
        return

    model.eval()
    # Copy before export so the FX flow starts from the same eager model.
    m_copy = copy.deepcopy(model)

    # 1. pytorch 2.0 export quantization flow (recommended/default flow)
    m = torch.export.export_for_training(model, copy.deepcopy(example_inputs)).module()
    quantizer = XNNPACKQuantizer()
    quantizer.set_global(get_symmetric_quantization_config(is_per_channel=True))
    m = prepare_pt2e(m, quantizer)
    # calibration
    after_prepare_result = m(*example_inputs)
    logging.info("prepare_pt2e: %s", m)
    m = convert_pt2e(m)
    after_quant_result = m(*example_inputs)

    # 2. the previous fx graph mode quantization reference flow
    qconfig_mapping = QConfigMapping().set_global(
        default_per_channel_symmetric_qnnpack_qconfig
    )
    backend_config = get_executorch_backend_config()
    m_fx = prepare_fx(
        m_copy, qconfig_mapping, example_inputs, backend_config=backend_config
    )
    after_prepare_result_fx = m_fx(*example_inputs)
    logging.info("prepare_fx: %s", m_fx)
    m_fx = _convert_to_reference_decomposed_fx(m_fx, backend_config=backend_config)
    after_quant_result_fx = m_fx(*example_inputs)

    # 3. compare results
    if model_name == "dl3":
        # dl3 output format: {"out": a, "aux": b}; only "out" is compared.
        after_prepare_result = after_prepare_result["out"]
        after_prepare_result_fx = after_prepare_result_fx["out"]
        after_quant_result = after_quant_result["out"]
        after_quant_result_fx = after_quant_result_fx["out"]
    logging.info("m: %s", m)
    logging.info("m_fx: %s", m_fx)

    # NB: this check is more useful for QAT. For PTQ we only insert observers,
    # which do not change the output of a model, so this just tests the numerical
    # difference between the two capture mechanisms. For QAT it also tests whether
    # the fake quant placement matches. Results are not exactly the same because
    # capture changes numerics, but they should still be really close.
    prepare_sqnr = compute_sqnr(after_prepare_result, after_prepare_result_fx)
    logging.info("prepare sqnr: %s", prepare_sqnr)
    # `not (x > t)` rather than `x <= t` so a NaN metric is reported as a failure.
    if not (prepare_sqnr > 100):
        raise AssertionError(
            f"prepare sqnr between pt2e and fx flows is {prepare_sqnr}, expected > 100"
        )

    # Use the absolute difference: a signed max would let arbitrarily large
    # negative deviations pass unnoticed.
    quant_diff_max = torch.max(torch.abs(after_quant_result - after_quant_result_fx))
    logging.info("quant diff max: %s", quant_diff_max)
    if not (quant_diff_max < 1e-1):
        raise AssertionError(
            f"max abs diff between pt2e and fx quantized outputs is {quant_diff_max}, "
            "expected < 1e-1"
        )

    quant_sqnr = compute_sqnr(after_quant_result, after_quant_result_fx)
    logging.info("quant sqnr: %s", quant_sqnr)
    if not (quant_sqnr > 30):
        raise AssertionError(
            f"quant sqnr between pt2e and fx flows is {quant_sqnr}, expected > 30"
        )


def _parse_args() -> argparse.Namespace:
    """Define this script's command-line options and parse them from ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_OPTIONS.keys())}",
    )
    parser.add_argument(
        "-ve",
        "--verify",
        action="store_true",
        required=False,
        default=False,
        help="flag for verifying XNNPACKQuantizer against fx graph mode quantization",
    )
    parser.add_argument(
        "-s",
        "--so_library",
        required=False,
        help="shared library for quantized operators",
    )
    return parser.parse_args()


def _ensure_quantized_out_ops(so_library) -> None:
    """Make sure the quantized op out variants are registered, loading ``so_library`` if not.

    Raises:
        RuntimeError: If the ops are not registered and no library path was given.
    """
    # See if we have quantized op out variants registered
    try:
        _ = torch.ops.quantized_decomposed.add.out
    except AttributeError:
        logging.info("No registered quantized ops")
    else:
        return

    if not so_library:
        raise RuntimeError(
            "Need to specify shared library path to register quantized ops (and their out variants) into "
            "EXIR. The required shared library is defined as `quantized_ops_aot_lib` in "
            "kernels/quantized/CMakeLists.txt if you are using CMake build, or `aot_lib` in "
            "kernels/quantized/targets.bzl for buck2. One example path would be cmake-out/kernels/quantized/"
            "libquantized_ops_aot_lib.[so|dylib]."
        )
    torch.ops.load_library(so_library)


def main() -> None:
    """Quantize the requested model and save it as an ExecuTorch program.

    Steps: parse arguments, make sure the quantized out-variant ops are
    registered, create the eager model, optionally verify XNNPACKQuantizer
    against the FX quantization flow (``--verify``), then export, quantize,
    lower to edge, convert to an ExecuTorch program and save it via
    ``save_pte_program`` under the name ``<model_name>_quantized``.

    Raises:
        RuntimeError: If the quantized ops are unavailable and no ``--so_library``
            was given, if the model is not listed as quantizable (only checked
            without ``--verify``), or if the model name is unknown.
    """
    args = _parse_args()
    _ensure_quantized_out_ops(args.so_library)

    if not args.verify and args.model_name not in MODEL_NAME_TO_OPTIONS:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name or not quantizable right now, "
            "please contact executorch team if you want to learn why or how to support "
            "quantization for the requested model. "
            f"Available models are {list(MODEL_NAME_TO_OPTIONS.keys())}."
        )

    # With --verify the check above is skipped, so an unknown name would otherwise
    # surface here as a bare KeyError.
    try:
        model_spec = MODEL_NAME_TO_MODEL[args.model_name]
    except KeyError as err:
        raise RuntimeError(
            f"Model {args.model_name} is not a known model name."
        ) from err
    model, example_inputs, _, _ = EagerModelFactory.create_model(*model_spec)

    if args.verify:
        verify_xnnpack_quantizer_matching_fx_quant_model(
            args.model_name, model, example_inputs
        )

    model = model.eval()
    # pre-autograd export. eventually this will become torch.export
    model = torch.export.export_for_training(model, example_inputs).module()
    start = time.perf_counter()
    quantized_model = quantize(model, example_inputs)
    end = time.perf_counter()
    logging.info(f"Quantize time: {end - start}s")

    start = time.perf_counter()
    edge_compile_config = EdgeCompileConfig(_check_ir_validity=False)
    edge_m = export_to_edge(
        quantized_model, example_inputs, edge_compile_config=edge_compile_config
    )
    end = time.perf_counter()
    logging.info(f"Export time: {end - start}s")

    # "Save time" covers both the conversion to an ExecuTorch program and the save.
    start = time.perf_counter()
    prog = edge_m.to_executorch(
        config=ExecutorchBackendConfig(extract_delegate_segments=False)
    )
    save_pte_program(prog, f"{args.model_name}_quantized")
    end = time.perf_counter()
    logging.info(f"Save time: {end - start}s")
    logging.info("finished")


# Run the quantize-and-export pipeline only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()  # pragma: no cover
