import itertools
import math
import operator

import numpy as np

import torch

from . import benchmark


class BroadcastMulBench(benchmark.Benchmark):
    """Benchmark a two-operand broadcasting expression ``d1 + d2``.

    Despite the class name, ``forward`` computes an addition. ``case`` selects
    which axes of the two operands are broadcast; every case produces an
    ``[M, N, K]`` result:

    * ``"row"``: ``[M, N, 1] + [M, 1, K]``
    * ``"mid"``: ``[M, N, 1] + [1, N, K]``
    * ``"col"``: ``[M, 1, K] + [1, N, K]``

    Raises:
        ValueError: if ``case`` is not one of the values listed above.
    """

    def __init__(self, mode, device, dtype, case, M, N, K):
        super().__init__(mode, device, dtype)
        self.case = case
        self.M = M
        self.N = N
        self.K = K

        # Operand shapes per case; the size-1 axes are the broadcast ones.
        case_shapes = {
            "row": ([M, N, 1], [M, 1, K]),
            "mid": ([M, N, 1], [1, N, K]),
            "col": ([M, 1, K], [1, N, K]),
        }
        try:
            shape1, shape2 = case_shapes[case]
        except (KeyError, TypeError):
            raise ValueError(f"invalid case: {case}") from None

        self.d1 = self.rand(
            shape1, device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.d2 = self.rand(
            shape2, device=device, dtype=dtype, requires_grad=self.requires_grad
        )

        self.inputs = [self.d1, self.d2]

    def forward(self, d1, d2):
        y = d1 + d2
        return y

    def reference(self):
        return self.numpy(self.d1) + self.numpy(self.d2)

    def config(self):
        return [self.M, self.N, self.K]

    @staticmethod
    def default_configs():
        return [[128, 256, 128]]

    def memory_workload(self):
        """Return the memory-traffic estimate keyed by "sol" and "algorithmic".

        Counts are in units of one ``[M, N, K]`` buffer: forward touches one
        buffer; backward counts two for "sol" and three for "algorithmic".
        """
        if self.mode == "fwd":
            sol_count = 1
            algorithmic_count = 1
        else:
            sol_count = (1) + (1)
            algorithmic_count = 1 + (1 + 1)

        # Scale by 4 bytes per element, as BroadcastThreeArgs and BroadcastBench
        # do, so the figure is in bytes rather than elements (presumably this
        # assumes 4-byte / float32 elements -- TODO confirm for other dtypes).
        buffer_size = self.M * self.N * self.K * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }


class BroadcastRowBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the "row" case: ``[M, N, 1] + [M, 1, K]``."""

    def __init__(self, mode, device, dtype, M, N, K):
        # Pin the broadcast case; everything else is forwarded unchanged.
        super().__init__(mode, device, dtype, case="row", M=M, N=N, K=K)

    @staticmethod
    def module():
        """Name under which this benchmark is exposed."""
        return "broadcast_row"


class BroadcastMidBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the "mid" case: ``[M, N, 1] + [1, N, K]``."""

    def __init__(self, mode, device, dtype, M, N, K):
        # Pin the broadcast case; everything else is forwarded unchanged.
        super().__init__(mode, device, dtype, case="mid", M=M, N=N, K=K)

    @staticmethod
    def module():
        """Name under which this benchmark is exposed."""
        return "broadcast_mid"


class BroadcastColBench(BroadcastMulBench):
    """``BroadcastMulBench`` fixed to the "col" case: ``[M, 1, K] + [1, N, K]``."""

    def __init__(self, mode, device, dtype, M, N, K):
        # Pin the broadcast case; everything else is forwarded unchanged.
        super().__init__(mode, device, dtype, case="col", M=M, N=N, K=K)

    @staticmethod
    def module():
        """Name under which this benchmark is exposed."""
        return "broadcast_col"


class BroadcastThreeArgs(benchmark.Benchmark):
    """Benchmark the three-operand broadcasting sum ``d1 + d2 + d3``.

    Operands have shapes ``[M, N]``, ``[K, M, 1]`` and ``[L, K, 1, 1]``, which
    broadcast together to an ``[L, K, M, N]`` result.
    """

    def __init__(self, mode, device, dtype, M, N, K, L):
        super().__init__(mode, device, dtype)
        self.M, self.N, self.K, self.L = M, N, K, L

        # Operands are created in order d1, d2, d3.
        operand_shapes = ([M, N], [K, M, 1], [L, K, 1, 1])
        self.inputs = [
            self.rand(
                shape, device=device, dtype=dtype, requires_grad=self.requires_grad
            )
            for shape in operand_shapes
        ]
        self.d1, self.d2, self.d3 = self.inputs

    def forward(self, d1, d2, d3):
        return d1 + d2 + d3

    def reference(self):
        a1, a2, a3 = (self.numpy(t) for t in (self.d1, self.d2, self.d3))
        return a1 + a2 + a3

    def config(self):
        return [self.M, self.N, self.K, self.L]

    @staticmethod
    def default_configs():
        return [[32, 16, 64, 128]]

    def memory_workload(self):
        """Return the memory-traffic estimate keyed by "sol" and "algorithmic".

        Counts are in units of one ``[L, K, M, N]`` buffer, scaled by 4
        (presumably bytes per element -- assumes 4-byte elements).
        """
        is_forward = self.mode == "fwd"
        sol_count = 1 if is_forward else 2
        algorithmic_count = 1 if is_forward else 4

        buffer_size = self.M * self.N * self.K * self.L * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def module():
        return "broadcast_3args"


# benchmark.register_benchmark_class(BroadcastRowBench)
# benchmark.register_benchmark_class(BroadcastMidBench)
# benchmark.register_benchmark_class(BroadcastColBench)
# benchmark.register_benchmark_class(BroadcastThreeArgs)


# TODO: merge this with elementwise bench
# A template class for broadcasting element-wise operations.
# Derived classes override its class variables to customize its behavior.
class BroadcastBench(benchmark.Benchmark):
    """Template for broadcasting element-wise benchmarks.

    The benchmarked expression is
    ``binary(unary(d1), unary(d2)) + binary(unary(d3), unary(d4))``, where
    ``unary`` defaults to the identity and ``binary`` defaults to addition.
    When ``split_input`` is False ("shared" mode), ``d1 + 0.001`` is used in
    place of ``d3``. Subclasses customize the benchmark by overriding the
    class variables below.
    """

    # Customization hooks; a falsy op means "use the default".
    op_str = None  # suffix of the module name returned by module()
    binary_op_pt_func = None  # binary op used by forward()
    binary_op_np_func = None  # binary op used by reference()
    unary_op_pt_func = None  # unary op used by forward()
    unary_op_np_func = None  # unary op used by reference()
    split_input = True  # False: reuse d1 (+ 0.001) instead of d3

    def __init__(self, mode, device, dtype, M, N, K):
        super().__init__(mode, device, dtype)
        self.M = M
        self.N = N
        self.K = K
        # Operands are created in order d1, d2, d3, d4.
        self.d1, self.d2, self.d3, self.d4 = (
            self.rand(
                shape, device=device, dtype=dtype, requires_grad=self.requires_grad
            )
            for shape in ([M, N], [K, 1, N], [M, N], [K, M, 1])
        )
        self.inputs = [self.d1, self.d2, self.d3, self.d4]

    def _eval(self, d1, d2, d3, d4, binary_op, unary_op):
        """Evaluate the template expression; falsy ops fall back to defaults."""
        combine = binary_op if binary_op else (lambda x, y: x + y)
        transform = unary_op if unary_op else (lambda x: x)

        # Operands are transformed in order 1..4 before combining, so traced
        # graphs see the ops in the same sequence.
        t1 = transform(d1)
        t2 = transform(d2)
        t3 = transform(d3 if self.split_input else d1 + 0.001)
        t4 = transform(d4)
        return combine(t1, t2) + combine(t3, t4)

    def forward(self, d1, d2, d3, d4):
        cls = type(self)
        return self._eval(
            d1, d2, d3, d4, cls.binary_op_pt_func, cls.unary_op_pt_func
        )

    def reference(self):
        cls = type(self)
        arrays = [self.numpy(t) for t in (self.d1, self.d2, self.d3, self.d4)]
        return self._eval(*arrays, cls.binary_op_np_func, cls.unary_op_np_func)

    def config(self):
        return [self.M, self.N, self.K]

    @classmethod
    def module(cls):
        return "broadcast_" + cls.op_str

    def memory_workload(self):
        """Return the memory-traffic estimate keyed by "sol" and "algorithmic".

        Split and shared modes currently use identical counts: one buffer in
        forward; in backward one for "sol" and one per input for
        "algorithmic". Buffers are ``M * N * K`` elements scaled by 4
        (presumably bytes per element -- assumes 4-byte elements).
        """
        sol_count = 1
        algorithmic_count = 1 if self.mode == "fwd" else len(self.inputs)

        buffer_size = self.M * self.N * self.K * 4
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        return [[256, 128, 512]]


def np_erf(x):
    """Return the element-wise error function of ``x`` as a NumPy array.

    NumPy has no ``erf`` function (``np.erf`` raises ``AttributeError``), so
    this applies the standard-library ``math.erf`` element by element. It is
    used as the NumPy reference for ``torch.erf``. Floating-point inputs keep
    their dtype; any other input yields float64.
    """
    arr = np.asarray(x)
    # An explicit otypes lets np.vectorize accept size-0 inputs.
    out = np.vectorize(math.erf, otypes=[np.float64])(arr)
    if np.issubdtype(arr.dtype, np.floating):
        out = out.astype(arr.dtype)
    return out


def register_broadcast_ops():
    """Generate and register one ``BroadcastBench`` subclass per op and mode.

    Each op-table entry is either ``[name, pt_func]``, where the same callable
    also serves as the NumPy reference, or ``[name, pt_func, np_func]``. Every
    op is registered twice: with ``split_input=True`` (named
    ``split_<name>``) and with ``split_input=False`` (named
    ``shared_<name>``). All binary ops are registered before all unary ops.

    Raises:
        ValueError: if an op-table entry does not have 2 or 3 elements.
    """
    binary_op_list = [
        ["mul", operator.mul],
        ["add", operator.add],
        ["sub", operator.sub],
        # Offset the divisor to avoid dividing by zero.
        ["div", lambda a, b: a / (b + 1e-4)],
        ["pow", torch.pow, np.power],  # no fusion triggered
        ["max", torch.max, np.maximum],
        ["min", torch.min, np.minimum],
    ]

    unary_op_list = [
        ["erf", torch.erf, np_erf],
        ["exp", torch.exp, np.exp],
        ["sin", torch.sin, np.sin],
        ["cos", torch.cos, np.cos],
    ]

    def register_variant(op_spec, split_input, kind):
        # Build a BroadcastBench subclass whose `<kind>_op_*` hooks are set.
        if len(op_spec) == 2:
            name, pt_func = op_spec
            np_func = pt_func
        elif len(op_spec) == 3:
            name, pt_func, np_func = op_spec
        else:
            raise ValueError(f"op spec must have 2 or 3 entries, got {op_spec!r}")

        op_str = ("split" if split_input else "shared") + "_" + name
        bm_cls = type(
            "BroadcastBench_" + op_str,
            (BroadcastBench,),
            {
                "op_str": op_str,
                f"{kind}_op_pt_func": pt_func,
                f"{kind}_op_np_func": np_func,
                "split_input": split_input,
            },
        )
        benchmark.register_benchmark_class(bm_cls)

    for split_input, op_spec in itertools.product([True, False], binary_op_list):
        register_variant(op_spec, split_input, "binary")

    for split_input, op_spec in itertools.product([True, False], unary_op_list):
        register_variant(op_spec, split_input, "unary")


# Build and register every BroadcastBench variant when this module is imported.
register_broadcast_ops()
