import operator_benchmark as op_bench

import torch


"""Microbenchmarks for quantized unary operators (point-wise and reduction)."""


# Configs for pointwise and reduction unary ops
# "short" config: a single M=512, N=512 case; quint8 is the only dtype listed.
qunary_ops_configs_short = op_bench.config_list(
    attr_names=["M", "N"],
    attrs=[[512, 512]],
    cross_product_configs={"dtype": [torch.quint8]},
    tags=["short"],
)

# "long" configs: M and N each take the values 256 and 1024, and dtype takes
# quint8, qint8 and qint32. Going by the helper's name these lists are
# presumably expanded into every combination -- confirm in operator_benchmark.
qunary_ops_configs_long = op_bench.cross_product_configs(
    M=[256, 1024],
    N=[256, 1024],
    dtype=[torch.quint8, torch.qint8, torch.qint32],
    tags=["long"],
)


class QUnaryOpBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case that applies one callable to a per-tensor quantized input.

    ``init`` builds the quantized tensor and stores the callable under test;
    ``forward`` calls that callable on the tensor.
    """

    def init(self, M, N, dtype, op_func):
        """Create the quantized input and remember the op to run.

        Args:
            M: First size passed to ``torch.rand``.
            N: Second size passed to ``torch.rand``.
            dtype: Quantized dtype handed to ``torch.quantize_per_tensor``.
            op_func: Callable that ``forward`` invokes on the quantized tensor.
        """
        # Quantization parameters are fixed: scale 1.0, zero point 0.
        # NOTE(review): with scale=1.0 the [0, 1) samples drawn by torch.rand
        # should quantize to very few distinct values -- confirm this is the
        # intended input distribution for ops such as sort/argsort.
        quantized = torch.quantize_per_tensor(
            torch.rand(M, N), scale=1.0, zero_point=0, dtype=dtype
        )
        # The key matches the parameter name of forward(); presumably the
        # benchmark framework feeds these entries to forward by name.
        self.inputs = {"q_input": quantized}
        self.op_func = op_func

    def forward(self, q_input):
        """Run the stored op on ``q_input`` and return whatever it returns."""
        return self.op_func(q_input)


# Table of unary ops to benchmark. Each active row is [op_name, op_func]
# (matching attr_names below); op_func is the callable that
# QUnaryOpBenchmark.forward applies to the quantized input.
# Rows that are commented out are ops waiting on quantized-tensor support.
# TODO: Uncomment the ops whenever they are implemented for quantized tensor.
qunary_ops_list = op_bench.op_list(
    attr_names=["op_name", "op_func"],
    attrs=[
        # ['q_abs', torch.abs],
        # ['q_abs_', torch.abs_],
        # ['q_acos', torch.acos],
        # ['q_acos_', torch.acos_],
        ["q_argsort", torch.argsort],
        # ['q_asin', torch.asin],
        # ['q_asin_', torch.asin_],
        # ['q_atan', torch.atan],
        # ['q_atan_', torch.atan_],
        # ['q_ceil', torch.ceil],
        # ['q_ceil_', torch.ceil_],
        ["q_clone", torch.clone],
        # ['q_cos', torch.cos],
        # ['q_cos_', torch.cos_],
        # ['q_cosh', torch.cosh],
        # ['q_digamma', torch.digamma],
        # ['q_erf', torch.erf],
        # ['q_erf_', torch.erf_],
        # ['q_erfc', torch.erfc],
        # ['q_erfc_', torch.erfc_],
        # ['q_erfinv', torch.erfinv],
        # ['q_exp', torch.exp],
        # ['q_exp_', torch.exp_],
        # ['q_expm1', torch.expm1],
        # ['q_expm1_', torch.expm1_],
        # ['q_floor', torch.floor],
        # ['q_floor_', torch.floor_],
        # ['q_frac', torch.frac],
        # ['q_frac_', torch.frac_],
        # ['q_hardshrink', torch.hardshrink],
        # ['q_lgamma', torch.lgamma],
        # ['q_log', torch.log],
        # ['q_log10', torch.log10],
        # ['q_log10_', torch.log10_],
        # ['q_log1p', torch.log1p],
        # ['q_log1p_', torch.log1p_],
        # ['q_log2', torch.log2],
        # ['q_log2_', torch.log2_],
        # ['q_log_', torch.log_],
        ["q_mean", torch.mean],
        # ['q_neg', torch.neg],
        # ['q_neg_', torch.neg_],
        # ['q_reciprocal', torch.reciprocal],
        # ['q_reciprocal_', torch.reciprocal_],
        ["q_relu", torch.relu],
        ["q_relu_", torch.relu_],
        # ['q_round', torch.round],
        # ['q_round_', torch.round_],
        # ['q_rsqrt', torch.rsqrt],
        # ['q_rsqrt_', torch.rsqrt_],
        # ['q_sigmoid', torch.sigmoid],
        # ['q_sigmoid_', torch.sigmoid_],
        # ['q_sign', torch.sign],
        # ['q_sin', torch.sin],
        # ['q_sin_', torch.sin_],
        # ['q_sinh', torch.sinh],
        ["q_sort", torch.sort],
        # ['q_sqrt', torch.sqrt],
        # ['q_sqrt_', torch.sqrt_],
        # ['q_tan', torch.tan],
        # ['q_tan_', torch.tan_],
        # ['q_tanh', torch.tanh],
        # ['q_tanh_', torch.tanh_],
        # ['q_trunc', torch.trunc],
        # ['q_trunc_', torch.trunc_],
        # ['q_unique', torch.unique],
        # ['q_zero_', torch.zero_],
        # ['q_bernoulli_', lambda t: t.bernoulli_()],
        # ['q_cauchy_', lambda t: t.cauchy_()],
        # ['q_digamma_', lambda t: t.digamma_()],
        # ['q_exponential_', lambda t: t.exponential_()],
        # ['q_normal_', lambda t: t.normal_()],
        # ['q_random_', lambda t: t.random_()],
        # ['q_sign_', lambda t: t.sign_()],
        # ['q_uniform_', lambda t: t.uniform_()],
        # ['q_half', lambda t: t.half()],
        # ['q_long', lambda t: t.long()],
    ],
)


# Hand the op table, the concatenated short + long configs, and the benchmark
# class to op_bench's test generator.
op_bench.generate_pt_tests_from_op_list(
    qunary_ops_list,
    qunary_ops_configs_short + qunary_ops_configs_long,
    QUnaryOpBenchmark,
)


# === Other unary ops (i.e. the ones that need parameters as args) ===

# Configs for the topk benchmark (M, N input sizes plus the k argument)
# "short" topk config: a single M=512, N=512, k=5 case; quint8 is the only
# dtype listed.
qunary_ops_topk_configs_short = op_bench.config_list(
    attr_names=["M", "N", "k"],
    attrs=[[512, 512, 5]],
    cross_product_configs={"dtype": [torch.quint8]},
    tags=["short"],
)

# "long" topk configs: M and N each take 256 and 1024, k takes 1, 3 and 5, and
# dtype takes quint8, qint8 and qint32. Going by the helper's name these lists
# are presumably expanded into every combination -- confirm in
# operator_benchmark.
qunary_ops_topk_configs_long = op_bench.cross_product_configs(
    M=[256, 1024],
    N=[256, 1024],
    k=[1, 3, 5],
    dtype=[torch.quint8, torch.qint8, torch.qint32],
    tags=["long"],
)


class QTopkOpBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark case for ``torch.topk`` on a per-tensor quantized input."""

    def init(self, M, N, dtype, k):
        """Create the quantized input and record ``k`` as a forward argument.

        Args:
            M: First size passed to ``torch.rand``.
            N: Second size passed to ``torch.rand``.
            dtype: Quantized dtype handed to ``torch.quantize_per_tensor``.
            k: Value forwarded to ``torch.topk`` as its ``k`` argument.
        """
        # Quantization parameters are fixed: scale 1.0, zero point 0.
        # NOTE(review): with scale=1.0 the [0, 1) samples drawn by torch.rand
        # should quantize to very few distinct values -- confirm that such a
        # tie-heavy input is what the topk benchmark is meant to measure.
        quantized = torch.quantize_per_tensor(
            torch.rand(M, N), scale=1.0, zero_point=0, dtype=dtype
        )
        # Keys match the parameter names of forward(); presumably the
        # benchmark framework feeds these entries to forward by name.
        self.inputs = {"q_input": quantized, "k": k}
        self.set_module_name("qtopk")

    def forward(self, q_input, k: int):
        """Return ``torch.topk(q_input, k)``."""
        return torch.topk(q_input, k)


# Hand the concatenated short + long topk configs and the benchmark class to
# op_bench's test generator.
op_bench.generate_pt_test(
    qunary_ops_topk_configs_short + qunary_ops_topk_configs_long, QTopkOpBenchmark
)


# Start the benchmark runner only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    op_bench.benchmark_runner.main()
