import operator_benchmark as op_bench

import torch
import torch.ao.quantization.observer as obs


# "short" suite for the min/max observers: one fixed 3 x 512 x 512 input,
# listed once per device.
qobserver_short_configs_dict = {
    "attr_names": ("C", "M", "N", "dtype", "device"),
    "attrs": tuple(
        (3, 512, 512, torch.quint8, device) for device in ("cpu", "cuda")
    ),
    "tags": ("short",),
}

# "short" suite for the histogram observer: same attribute names and tag as
# above, but only the CPU entry is kept.
q_hist_observer_short_configs_dict = {
    **qobserver_short_configs_dict,
    "attrs": ((3, 512, 512, torch.quint8, "cpu"),),
}

# Values swept for both the M and N dimensions in the "long" suites.
_long_mn_sizes = (256, 1024)

# "long" suite for the min/max observers. Key order is kept as
# C, M, N, device, dtype, tags because these dicts are expanded as keyword
# arguments further down.
qobserver_long_configs_dict = {
    "C": (32, 64),
    "M": _long_mn_sizes,
    "N": _long_mn_sizes,
    "device": ("cpu", "cuda"),
    "dtype": (torch.quint8,),  # dtype doesn't change the timing, keep the same
    "tags": ("long",),
}

# "long" suite for the histogram observer: same M/N/dtype/tags as above, with
# different channel counts and CPU as the only device.
q_hist_observer_long_configs_dict = {
    **qobserver_long_configs_dict,
    "C": (1, 3, 8),
    "device": ("cpu",),
}


# Per-tensor quantization schemes used by the min/max observer configs.
_per_tensor_qschemes = (torch.per_tensor_affine, torch.per_tensor_symmetric)

# Short suite: the explicit short config entries, with qscheme supplied
# through ``cross_product_configs``.
qobserver_per_tensor_configs_short = op_bench.config_list(
    cross_product_configs={"qscheme": _per_tensor_qschemes},
    **qobserver_short_configs_dict,
)

# Long suite: qscheme is passed first, followed by the long config dict's keys
# in their declared order.
qobserver_per_tensor_configs_long = op_bench.cross_product_configs(
    qscheme=_per_tensor_qschemes,
    **qobserver_long_configs_dict,
)

# Per-channel quantization schemes used by the per-channel observer configs.
_per_channel_qschemes = (torch.per_channel_affine, torch.per_channel_symmetric)

# Short suite: the explicit short config entries, with qscheme supplied
# through ``cross_product_configs``.
qobserver_per_channel_configs_short = op_bench.config_list(
    cross_product_configs={"qscheme": _per_channel_qschemes},
    **qobserver_short_configs_dict,
)

# Long suite: qscheme is passed first, followed by the long config dict's keys
# in their declared order.
qobserver_per_channel_configs_long = op_bench.cross_product_configs(
    qscheme=_per_channel_qschemes,
    **qobserver_long_configs_dict,
)

# Quantization schemes used by the histogram observer configs (per-tensor only).
_hist_observer_qschemes = (torch.per_tensor_affine, torch.per_tensor_symmetric)

# Short suite: the histogram-specific short config entries, with qscheme
# supplied through ``cross_product_configs``.
q_hist_observer_per_tensor_configs_short = op_bench.config_list(
    cross_product_configs={"qscheme": _hist_observer_qschemes},
    **q_hist_observer_short_configs_dict,
)

# Long suite: qscheme is passed first, followed by the histogram long config
# dict's keys in their declared order.
q_hist_observer_per_tensor_configs_long = op_bench.cross_product_configs(
    qscheme=_hist_observer_qschemes,
    **q_hist_observer_long_configs_dict,
)


# (benchmark name, observer class) rows for the observers that are registered
# with the per-tensor configs.
_per_tensor_observer_rows = [
    ["MinMaxObserver", obs.MinMaxObserver],
    ["MovingAverageMinMaxObserver", obs.MovingAverageMinMaxObserver],
]

qobserver_per_tensor_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=_per_tensor_observer_rows,
)

# (benchmark name, observer class) rows for the observers that are registered
# with the per-channel configs.
_per_channel_observer_rows = [
    ["PerChannelMinMaxObserver", obs.PerChannelMinMaxObserver],
    [
        "MovingAveragePerChannelMinMaxObserver",
        obs.MovingAveragePerChannelMinMaxObserver,
    ],
]

qobserver_per_channel_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=_per_channel_observer_rows,
)

# (benchmark name, observer class) rows for the histogram observer.
# NOTE(review): both rows map to obs.HistogramObserver, and this list is
# registered below with a single benchmark class
# (QObserverBenchmarkCalculateQparams), so the two names appear to exercise
# the same code path -- confirm whether that is intended.
_hist_observer_rows = [
    ["HistogramObserver", obs.HistogramObserver],
    ["HistogramObserverCalculateQparams", obs.HistogramObserver],
]

q_hist_observer_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=_hist_observer_rows,
)


class QObserverBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case whose forward observes a tensor and then computes qparams."""

    def init(self, C, M, N, dtype, qscheme, op_func, device):
        """Create a random (C, M, N) input on ``device`` and the observer.

        ``op_func`` is the observer class; it is instantiated with ``dtype``
        and ``qscheme`` and then moved to ``device``.
        """
        sample = torch.rand(C, M, N, device=device)
        self.inputs = {"f_input": sample}

        observer = op_func(dtype=dtype, qscheme=qscheme)
        self.op_func = observer.to(device)

    def forward(self, f_input):
        """Run the observer on ``f_input`` and return ``calculate_qparams()``."""
        # The observer call's own return value is discarded; only the
        # quantization parameters computed afterwards are returned.
        self.op_func(f_input)
        qparams = self.op_func.calculate_qparams()
        return qparams


class QObserverBenchmarkCalculateQparams(op_bench.TorchBenchmarkBase):
    """Benchmark case whose forward only calls ``calculate_qparams()``.

    The observer is fed its input once during ``init``; ``forward`` takes no
    inputs.
    """

    def init(self, C, M, N, dtype, qscheme, op_func, device):
        """Build the observer, run it once on a random (C, M, N) tensor."""
        sample = torch.rand(C, M, N, device=device)
        observer = op_func(dtype=dtype, qscheme=qscheme).to(device)

        self.f_input = sample
        self.q_observer = observer
        # Observe the sample up front so forward() has statistics to work with.
        observer(sample)

        # forward() takes no arguments, so there are no benchmark inputs.
        self.inputs = {}

    def forward(self):
        """Return the qparams computed from the data observed in ``init``."""
        qparams = self.q_observer.calculate_qparams()
        return qparams


# Register the benchmarks. Each row is (op list, short + long configs,
# benchmark class); rows are processed in order: per-tensor, per-channel,
# histogram.
for _op_list, _configs, _bench_cls in (
    (
        qobserver_per_tensor_list,
        qobserver_per_tensor_configs_short + qobserver_per_tensor_configs_long,
        QObserverBenchmark,
    ),
    (
        qobserver_per_channel_list,
        qobserver_per_channel_configs_short + qobserver_per_channel_configs_long,
        QObserverBenchmark,
    ),
    (
        q_hist_observer_list,
        q_hist_observer_per_tensor_configs_short
        + q_hist_observer_per_tensor_configs_long,
        QObserverBenchmarkCalculateQparams,
    ),
):
    op_bench.generate_pt_tests_from_op_list(_op_list, _configs, _bench_cls)


# When executed as a script, hand control to the operator_benchmark runner.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()
