import operator_benchmark as op_bench

import torch
import torch.ao.nn.quantized as nnq
import torch.ao.quantization as tq
import torch.nn as nn


"""Microbenchmarks for general quantization operations."""

# mode is used to show the direction of the benchmark:
# if 'Q', benchmark quantization, else dequantization

quantize_configs_short_dict = {
    "attr_names": ["C", "M", "N", "dtype", "mode"],
    "attrs": [
        [3, 512, 512, torch.quint8, "Q"],
        [3, 512, 512, torch.quint8, "D"],
    ],
    "tags": ["short"],
}

quantize_configs_long_dict = {
    "C": [3, 5, 8],  # this is reused for per-channel: avoid single channel test
    "M": [256, 1024],
    "N": [256, 1024],
    "dtype": [torch.quint8, torch.qint8, torch.qint32],
    "mode": ["D", "Q"],
    "tags": ["long"],
}


quantize_per_tensor_configs_short = op_bench.config_list(**quantize_configs_short_dict)

quantize_per_tensor_configs_long = op_bench.cross_product_configs(
    **quantize_configs_long_dict
)


class QuantizePerTensorBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks both quantization and dequantization."""

    def init(self, C, M, N, dtype, mode):
        assert mode in ("Q", "D")
        self.input = torch.rand(C, M, N)
        self.dtype = dtype
        self.op = nnq.Quantize(scale=1.0, zero_point=0, dtype=dtype)
        self.set_module_name("QuantizePerTensor")

        if mode == "D":
            self.input = self.op(self.input)
            self.op = nnq.DeQuantize()
            self.set_module_name("DequantizePerTensor")

        self.inputs = {"input": self.input}

    def forward(self, input):
        return self.op(input)


op_bench.generate_pt_test(
    quantize_per_tensor_configs_short + quantize_per_tensor_configs_long,
    QuantizePerTensorBenchmark,
)

# === Per Channel quantization ===

quantize_per_channel_configs_short = op_bench.config_list(
    cross_product_configs={"axis": (0,)}, **quantize_configs_short_dict
)

quantize_per_channel_configs_long = op_bench.cross_product_configs(
    axis=(0, 1, 2), **quantize_configs_long_dict
)


class QuantizePerChannelBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks both quantization and dequantization."""

    def init(self, C, M, N, dtype, axis, mode):
        assert mode in ("Q", "D")
        self.input = torch.rand(C, M, N)
        self.op = torch.quantize_per_channel

        channel_len = (C, M, N)[axis]

        self.kwargs = {
            "scales": torch.tensor([1.0] * channel_len),
            "zero_points": torch.tensor([0] * channel_len),
            "dtype": dtype,
            "axis": axis,
        }

        self.set_module_name("QuantizePerChannel")

        if mode == "D":
            self.input = self.op(self.input, **self.kwargs)

            def dequant(input, scales, zero_points, axis: int, dtype: int):
                return input.dequantize()

            self.op = dequant
            self.set_module_name("DequantizePerChannel")

        self.inputs = {
            "input": self.input,
            "scales": torch.tensor([1.0] * channel_len),
            "zero_points": torch.tensor([0] * channel_len),
            "axis": axis,
            "dtype": dtype,
        }

    def forward(self, input, scales, zero_points, axis: int, dtype: int):
        return self.op(
            input, scales=scales, zero_points=zero_points, axis=axis, dtype=dtype
        )


op_bench.generate_pt_test(
    quantize_per_channel_configs_short + quantize_per_channel_configs_long,
    QuantizePerChannelBenchmark,
)

# === Fake Quantization ===
# Generated benchmarks names start with 'learnable_kernel' or 'original_kernel',
#    for ex. 'original_kernel_nbits8_cpu_N1_C1_H256_W256_zero_point_dtypetorch.int32_bwdall'

fake_quantize_configs_short_dict = {
    "attr_names": ["N", "C", "H", "W", "zero_point_dtype"],
    "attrs": [
        [1, 3, 512, 512, torch.int32],
    ],
    "tags": ["short"],
}

fake_quantize_configs_long_dict = {
    "N": [1],
    "C": [1, 3, 8, 32],
    "H": [256, 1024],
    "W": [256, 1024],
    "zero_point_dtype": [torch.int32],
    "tags": ["long"],
}

fake_quantize_configs_short = op_bench.config_list(
    cross_product_configs={
        "device": ("cpu", "cuda"),
    },
    **fake_quantize_configs_short_dict,
)

fake_quantize_configs_long = op_bench.cross_product_configs(
    device=("cpu", "cuda"), **fake_quantize_configs_long_dict
)


class FakeQuantizeBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks fake quantization with default parameters."""

    def init(self, N, C, H, W, zero_point_dtype, device):
        self.inputs = {"input": torch.rand(N, C, H, W).to(device)}
        self.op = tq.FakeQuantize().to(device)
        self.set_module_name("FakeQuantize")

    def forward(self, input):
        return self.op(input)


op_bench.generate_pt_test(
    fake_quantize_configs_short + fake_quantize_configs_long, FakeQuantizeBenchmark
)


# op_type is used to describe the type of operator used in benchmarking:
# learnable_kernel represents the c++ kernel that can backpropagate on
# scale and zero point.
# original_kernel represents the original fake quantize c++ kernel.


def fakeQuantizePerTensorLearnableKernel(
    input, scale, zero_point, quant_min: int, quant_max: int
):
    return torch._fake_quantize_learnable_per_tensor_affine(
        input, scale, zero_point, quant_min, quant_max
    )


def fakeQuantizePerTensorOriginalKernel(
    input, scale, zero_point, quant_min: int, quant_max: int
):
    return torch.fake_quantize_per_tensor_affine(input, 1.0, 0, quant_min, quant_max)


fake_quantize_per_tensor_ops = op_bench.op_list(
    attrs=(
        ("learnable_kernel", fakeQuantizePerTensorLearnableKernel),
        ("original_kernel", fakeQuantizePerTensorOriginalKernel),
    ),
    attr_names=("op_name", "op_func"),
)

fake_quantize_operator_configs_short = op_bench.config_list(
    cross_product_configs={
        "nbits": (4, 8),
        "device": ("cpu", "cuda"),
    },
    **fake_quantize_configs_short_dict,
)

fake_quantize_operator_configs_long = op_bench.cross_product_configs(
    nbits=(4, 8), device=("cpu", "cuda"), **fake_quantize_configs_long_dict
)

# TODO(future PR) Combine config for floating point zero_point with other configs, once it is
# fully supported in all fakeQuant operators and devices for
# https://github.com/pytorch/pytorch/issues/61866.
fake_quantize_configs_long_dict_float_zero_point = (
    fake_quantize_configs_long_dict.copy()
)
fake_quantize_configs_long_dict_float_zero_point["zero_point_dtype"] = [
    torch.float32,
    torch.half,
]

fake_quantize_operator_configs_long_float_zero_point = op_bench.cross_product_configs(
    nbits=(8,),
    device=("cpu", "cuda"),
    **fake_quantize_configs_long_dict_float_zero_point,
)


class FakeQuantizePerTensorBaseOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks 3 different fake quantize per tensor operators."""

    def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
        self.quant_min = 0
        self.quant_max = 2**nbits - 1
        self.quant_range = 2**nbits
        self.input = nn.Parameter(
            torch.rand(N, C, H, W, dtype=torch.float, device=device),
            requires_grad=self.auto_set(),
        )
        self.scale = nn.Parameter(
            torch.tensor([1.0]).to(device), requires_grad=self.auto_set()
        )
        if op_func.__name__ == "fakeQuantizePerChannelOriginalKernel":
            self.zero_point = nn.Parameter(
                torch.tensor([0.0]).to(device).to(zero_point_dtype),
                requires_grad=self.auto_set(),
            )
        else:
            self.zero_point = nn.Parameter(
                torch.tensor([0.0]).to(device), requires_grad=self.auto_set()
            )

        self.inputs = {
            "input": self.input,
            "scale": self.scale,
            "zero_point": self.zero_point,
            "quant_min": self.quant_min,
            "quant_max": self.quant_max,
        }
        self.op_func = op_func

    def forward(self, input, scale, zero_point, quant_min: int, quant_max: int):
        return self.op_func(input, scale, zero_point, quant_min, quant_max)


op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_tensor_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerTensorBaseOpBenchmark,
)

op_bench.generate_pt_gradient_tests_from_op_list(
    fake_quantize_per_tensor_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerTensorBaseOpBenchmark,
)


def fakeQuantizePerChannelLearnableKernel(
    input, scale, zero_point, axis: int, quant_min: int, quant_max: int
):
    return torch._fake_quantize_learnable_per_channel_affine(
        input, scale, zero_point, axis, quant_min, quant_max
    )


def fakeQuantizePerChannelOriginalKernel(
    input, scale, zero_point, axis: int, quant_min: int, quant_max: int
):
    return torch.fake_quantize_per_channel_affine(
        input, scale, zero_point, axis, quant_min, quant_max
    )


fake_quantize_per_channel_ops = op_bench.op_list(
    attrs=(
        ("learnable_kernel", fakeQuantizePerChannelLearnableKernel),
        ("original_kernel", fakeQuantizePerChannelOriginalKernel),
    ),
    attr_names=("op_name", "op_func"),
)

fake_quantize_per_channel_float_zero_point_ops = op_bench.op_list(
    attrs=(("original_kernel", fakeQuantizePerChannelOriginalKernel),),
    attr_names=("op_name", "op_func"),
)


class FakeQuantizePerChannelOpBenchmark(op_bench.TorchBenchmarkBase):
    r"""Benchmarks 3 different fake quantize per channel operators."""

    def init(self, N, C, H, W, zero_point_dtype, nbits, device, op_func):
        self.quant_min = 0
        self.quant_max = 2**nbits - 1
        self.quant_range = 2**nbits
        # Axis is chosen with respect to the number of channels: C.
        self.axis = 1
        self.input = nn.Parameter(
            torch.rand(
                N,
                C,
                H,
                W,
                dtype=torch.float,
                device=device,
                requires_grad=self.auto_set(),
            )
        )

        if op_func.__name__ == "fakeQuantizePerChannelOriginalKernel":
            self.scale = torch.ones(
                C, device=device, dtype=torch.float32, requires_grad=False
            )
            self.zero_point = torch.zeros(
                C, device=device, dtype=zero_point_dtype, requires_grad=False
            )
        else:
            self.scale = nn.Parameter(
                torch.ones(C, device=device, dtype=torch.float32),
                requires_grad=self.auto_set(),
            )
            self.zero_point = nn.Parameter(
                torch.zeros(C, device=device, dtype=torch.float32),
                requires_grad=self.auto_set(),
            )

        self.inputs = {
            "input": self.input,
            "scale": self.scale,
            "zero_point": self.zero_point,
            "axis": self.axis,
            "quant_min": self.quant_min,
            "quant_max": self.quant_max,
        }

        self.op_func = op_func

    def forward(
        self, input, scale, zero_point, axis: int, quant_min: int, quant_max: int
    ):
        return self.op_func(input, scale, zero_point, axis, quant_min, quant_max)


op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_channel_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerChannelOpBenchmark,
)

op_bench.generate_pt_tests_from_op_list(
    fake_quantize_per_channel_float_zero_point_ops,
    fake_quantize_operator_configs_long_float_zero_point,
    FakeQuantizePerChannelOpBenchmark,
)

op_bench.generate_pt_gradient_tests_from_op_list(
    fake_quantize_per_channel_ops,
    fake_quantize_operator_configs_short + fake_quantize_operator_configs_long,
    FakeQuantizePerChannelOpBenchmark,
)

if __name__ == "__main__":
    op_bench.benchmark_runner.main()
