import torch

from . import benchmark


class RNNEltwise(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, b, hs):
        super().__init__(mode, device, dtype)
        self.b = b
        self.hs = hs
        self.input = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.hx = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.cx = self.rand(
            [b, hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.b_ih = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.b_hh = self.rand(
            [b, 4 * hs], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [
            self.input,
            self.hx,
            self.cx,
            self.b_ih,
            self.b_hh,
        ]

    def forward(self, input, hx, cx, b_ih, b_hh):
        gates = input + hx + b_ih + b_hh

        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, cy

    def config(self):
        return [self.b, self.hs]

    @staticmethod
    def module():
        return "rnn_eltwise"

    def memory_workload(self):
        def memsize(t):
            return t.numel() * t.element_size()

        input_size = sum(memsize(t) for t in self.inputs)
        output_size = 2 * memsize(self.cx)
        io_size = input_size + output_size
        return {"sol": io_size, "algorithmic": io_size}

    @staticmethod
    def default_configs():
        return [[64, 512]]


benchmark.register_benchmark_class(RNNEltwise)


class DynamicLSTM(benchmark.DynamicShape, RNNEltwise):
    def __init__(self, mode, device, dtype, b, hs):
        benchmark.DynamicShape.__init__(self)
        RNNEltwise.__init__(self, mode, device, dtype, b, hs)

    def instantiate_input(self):
        b, hs = self.rand_shape([self.b, self.hs])

        self.input = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.hx = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.cx = self.rand(
            [b, hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.b_ih = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.b_hh = self.rand(
            [b, 4 * hs],
            device=self.device,
            dtype=self.dtype,
            requires_grad=self.requires_grad,
        )
        self.inputs = [
            self.input,
            self.hx,
            self.cx,
            self.b_ih,
            self.b_hh,
        ]

    @staticmethod
    def module():
        return "dynamic_lstm"


benchmark.register_benchmark_class(DynamicLSTM)
