import itertools
import operator

import numpy as np
import scipy.special

import torch

from . import benchmark


# A template class for elementwise operations.
# A derived class overrides the class variables below to customize its behavior.
class ElementBench(benchmark.Benchmark):
    """Template benchmark that combines four inputs with elementwise ops.

    Subclasses customize the workload through the class variables below:
    a unary op applied to each operand, a binary op combining operand pairs,
    and whether the four operands are independent inputs or derived from the
    first one.
    """

    # Customization points; generated subclasses overwrite these.
    op_str = None
    binary_op_pt_func = None
    binary_op_np_func = None
    unary_op_pt_func = None
    unary_op_np_func = None
    split_input = True

    def __init__(self, mode, device, dtype, N):
        """Create four random inputs, each requested with shape ``[N]``."""
        super().__init__(mode, device, dtype)
        self.N = N
        tensors = [
            self.rand([N], device=device, dtype=dtype, requires_grad=self.requires_grad)
            for _ in range(4)
        ]
        self.d1, self.d2, self.d3, self.d4 = tensors
        self.inputs = [self.d1, self.d2, self.d3, self.d4]
        # Ops whose name contains "rand" are flagged as non-deterministic.
        self.deterministic = "rand" not in self.op_str

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        """Evaluate ``binary(u(d1), u(d2)) + binary(u(d3), u(d4))``.

        A falsy ``binary_op`` falls back to addition and a falsy ``unary_op``
        to the identity. When ``split_input`` is false, ``d2``..``d4`` are
        ignored and rebuilt from ``d1`` plus small constant offsets.
        """
        combine = binary_op or (lambda lhs, rhs: lhs + rhs)
        transform = unary_op or (lambda value: value)

        if self.split_input:
            d1, d2, d3, d4 = [transform(operand) for operand in (d1, d2, d3, d4)]
        else:
            # The offset operands are produced first and d1 is transformed
            # last, so the original d1 is still available for the offsets.
            d2, d3, d4 = [
                transform(d1 + offset) for offset in (0.001, 0.002, 0.003)
            ]
            d1 = transform(d1)

        left = combine(d1, d2)
        right = combine(d3, d4)
        return left + right

    def forward(self, d1, d2, d3, d4):
        """Run the workload with the ``*_pt_func`` ops of this class."""
        cls = self.__class__
        return self._eval(
            d1, d2, d3, d4, cls.binary_op_pt_func, cls.unary_op_pt_func
        )

    def reference(self):
        """Run the workload with the ``*_np_func`` ops on converted inputs."""
        cls = self.__class__
        arrays = [self.numpy(t) for t in (self.d1, self.d2, self.d3, self.d4)]
        return self._eval(*arrays, cls.binary_op_np_func, cls.unary_op_np_func)

    def config(self):
        """Return the benchmark configuration as ``[N]``."""
        return [self.N]

    @classmethod
    def module(cls):
        """Return the module name, ``"element_"`` followed by ``op_str``."""
        return "element_" + cls.op_str

    def memory_workload(self):
        """Return the ``"sol"`` and ``"algorithmic"`` workloads, scaled by N."""
        input_count = len(self.inputs)

        if "rand" in self.op_str:
            # "rand" ops override every other case, in both modes.
            sol_count = algorithmic_count = 1
        elif not self.split_input:
            # Shared input: same count in forward and non-forward modes.
            sol_count = algorithmic_count = 1 + 1
        elif self.mode == "fwd":
            sol_count = algorithmic_count = input_count + 1
        else:
            sol_count = (input_count + 1) + (1 + input_count)
            algorithmic_count = (input_count + 1) + ((2 + 1) * input_count)

        return {
            "sol": self.N * sol_count,
            "algorithmic": self.N * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configuration list: a single ``N`` of 2**25."""
        return [[1 << 25]]


def _register_element_variants(op_list, pt_attr, np_attr):
    """Create and register "split" and "shared" ElementBench subclasses.

    Args:
        op_list: entries of the form ``(name, pt_func)`` or
            ``(name, pt_func, np_func)``. With two items the same callable is
            used for both the PyTorch and the NumPy side.
        pt_attr: name of the class attribute that receives ``pt_func``.
        np_attr: name of the class attribute that receives ``np_func``.

    Raises:
        ValueError: if an entry has neither two nor three items.
    """
    for split_input, op_entry in itertools.product([True, False], op_list):
        if len(op_entry) == 2:
            op_name, op_pt_func = op_entry
            op_np_func = op_pt_func
        elif len(op_entry) == 3:
            op_name, op_pt_func, op_np_func = op_entry
        else:
            # Fail loudly instead of silently reusing the previous
            # iteration's values for a malformed entry.
            raise ValueError(
                "expected (name, pt_func[, np_func]), got %r" % (op_entry,)
            )

        split_str = "split" if split_input else "shared"
        op_str = split_str + "_" + op_name

        # Make a copy of ElementBench and customize its class variables.
        bm_cls = type("ElementBench_" + op_str, (ElementBench,), {})
        bm_cls.op_str = op_str
        setattr(bm_cls, pt_attr, op_pt_func)
        setattr(bm_cls, np_attr, op_np_func)
        bm_cls.split_input = split_input
        benchmark.register_benchmark_class(bm_cls)


def register_element_ops():
    """Register one ElementBench subclass per (input mode, op) combination.

    All binary ops are registered first, then all unary ops; within each
    group the "split" variants precede the "shared" variants.
    """
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        # The small epsilon keeps the divisor away from zero.
        ["div", lambda a, b: a / (b + 1e-4)],
        [
            "pow",
            torch.pow,
            np.power,
        ],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        ["erf", torch.erf, scipy.special.erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
        ["rand_like", torch.rand_like, lambda x: np.random.rand(*x.shape)],
    ]

    _register_element_variants(
        binary_op_list, "binary_op_pt_func", "binary_op_np_func"
    )
    _register_element_variants(
        unary_op_list, "unary_op_pt_func", "unary_op_np_func"
    )


# Register every generated ElementBench subclass when this module is imported.
register_element_ops()


class SimpleElementBench(benchmark.Benchmark):
    """Minimal elementwise benchmark: add two scalar constants to one input."""

    def __init__(self, mode, device, dtype, N):
        """Create a single random input, requested with shape ``[N]``."""
        super().__init__(mode, device, dtype)
        self.N = N
        self.data = self.rand(
            [N], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.data]

    def forward(self, data):
        """Return ``data + 0.001 + 0.002``, computed as two separate adds."""
        a = data + 0.001
        b = a + 0.002
        return b

    def reference(self):
        """Return the NumPy result of the same computation as ``forward``.

        Reads the current entry of ``self.inputs`` rather than ``self.data``
        so that the result still matches when a subclass replaces
        ``self.inputs`` after construction.
        """
        data = self.numpy(self.inputs[0])
        return (data + 0.001) + 0.002

    def config(self):
        """Return the benchmark configuration as ``[N]``."""
        return [self.N]

    @staticmethod
    def input_iterable():
        return True

    @classmethod
    def module(cls):
        """Return the module name of this benchmark."""
        return "simple_element"

    def memory_workload(self):
        """Return the ``"sol"`` and ``"algorithmic"`` workloads, scaled by N."""
        # The same buffer count applies regardless of the benchmark mode.
        buffer_count = 2
        return {
            "sol": self.N * buffer_count,
            "algorithmic": self.N * buffer_count,
        }

    @staticmethod
    def default_configs():
        """Return the default configuration list: a single ``N`` of 2**25."""
        return [[1 << 25]]


# Register SimpleElementBench with the benchmark module at import time.
benchmark.register_benchmark_class(SimpleElementBench)


class DynamicSimpleElementBench(benchmark.DynamicShape, SimpleElementBench):
    """SimpleElementBench variant whose input is rebuilt from ``rand_shape``."""

    def __init__(self, mode, device, dtype, N):
        # Both bases are initialized explicitly, DynamicShape first.
        benchmark.DynamicShape.__init__(self)
        SimpleElementBench.__init__(self, mode, device, dtype, N)

    @classmethod
    def module(cls):
        """Return the module name of this benchmark."""
        return "simple_dynamic_element"

    def instantiate_input(self):
        """Replace ``self.inputs`` with one tensor of a newly drawn shape."""
        # rand_shape must yield exactly one dimension for the 1-D input.
        (length,) = self.rand_shape([self.N])
        fresh = self.rand(
            [length],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.inputs = [fresh]


# Register DynamicSimpleElementBench with the benchmark module at import time.
benchmark.register_benchmark_class(DynamicSimpleElementBench)
