# This is a copy of rnn_attention from MLPerf, with some common sizes hardcoded
# for benchmarking and some control flow stripped out.
# https://github.com/mlperf/training/blob/master/rnn_translator/pytorch/seq2seq/models/attention.py

import torch

from . import benchmark


class BahdanauAttention(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, b, t_q, t_k, n):
        super().__init__(mode, device, dtype)
        self.b = b
        self.t_q = t_q
        self.t_k = t_k
        self.n = n
        self.att_query = self.rand(
            [b, t_q, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.att_keys = self.rand(
            [b, t_k, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.normalize_bias = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.linear_att = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [
            self.att_query,
            self.att_keys,
            self.normalize_bias,
            self.linear_att,
        ]

    def forward(self, att_query, att_keys, normalize_bias, linear_att):
        """
        Calculate Bahdanau score

        :param att_query: b x t_q x n
        :param att_keys: b x t_k x n

        return b x t_q x t_k scores
        """

        b, t_k, n = att_keys.size()
        t_q = att_query.size(1)

        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys + normalize_bias
        out = torch.tanh(sum_qk).matmul(linear_att)
        return out

    def reference(self):
        return self.numpy(self.forward(*self.inputs))

    def config(self):
        return [self.b, self.t_q, self.t_k, self.n]

    @staticmethod
    def module():
        return "attention"

    def memory_workload(self):
        def memsize(t):
            return t.numel() * t.element_size()

        input_size = (
            memsize(self.att_query)
            + memsize(self.att_keys)
            + memsize(self.normalize_bias)
            + memsize(self.linear_att)
        )
        output_size = 4 * torch.Size([self.b, self.t_q, self.t_k]).numel()
        io_size = input_size + output_size

        # If matmul is not fused, must write and then read `sum_qk`.
        intermediate_size = (
            2 * 4 * torch.Size([self.b, self.t_q, self.t_k, self.n]).numel()
        )
        return {"sol": io_size, "algorithmic": io_size + intermediate_size}

    @staticmethod
    def default_configs():
        mlperf_inference = [1280, 1, 66, 1024]
        nvidia = [128, 10, 128, 1024]
        return [mlperf_inference, nvidia]


benchmark.register_benchmark_class(BahdanauAttention)
