# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math

import torch
from functorch.dim import cat, dimlists, dims, softmax
from torch import nn


class Linear(nn.Linear):
    def forward(self, input):
        ci, co = dims()
        b = dimlists()
        result = (input[b, ci] * self.weight[co, ci]).sum(ci) + self.bias[co]
        return result.order(b, co)
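
# For example (an illustrative sketch, not part of the original code):
#   layer = Linear(4, 8)
#   out = layer(torch.randn(2, 3, 4))  # out.shape == (2, 3, 8)
# this works for any number of leading batch dimensions because the dimlist `b`
# absorbs them all; `ci` binds the input features and `co` the output features.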


class BertSelfAttention(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        attention_probs_dropout_prob,
        position_embedding_type=None,
        max_position_embeddings=None,
        linear=Linear,
    ):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"The hidden size ({hidden_size}) is not a multiple of the number of attention "
                f"heads ({num_attention_heads})"
            )

        self.num_attention_heads = num_attention_heads
        self.attention_head_size = int(hidden_size / num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = linear(hidden_size, self.all_head_size)
        self.key = linear(hidden_size, self.all_head_size)
        self.value = linear(hidden_size, self.all_head_size)

        self.dropout_prob = attention_probs_dropout_prob
        self.position_embedding_type = position_embedding_type

        if self.position_embedding_type is not None:
            assert max_position_embeddings is not None
            self.max_position_embeddings = max_position_embeddings
            self.distance_embedding = nn.Embedding(
                2 * max_position_embeddings - 1, self.attention_head_size
            )

    def forward(
        self,
        hidden_states,
        past_key_value=None,
    ):
        # first run the encoding linear layers for q, k, v normally
        # the meaning of a linear layer is well understood, so no need to use explicit dimensions
        q = self.query(hidden_states)
        k = self.key(hidden_states)
        v = self.value(hidden_states)
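        # at this point q, k, and v are still ordinary positional tensors of shape
        # (batch, sequence_length, all_head_size)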

        # introduce values that represent each dimension. dimensions are 'first class'
        # because they are actual python values introduced here
        batch, query_sequence, key_sequence, heads, features = dims()
        heads.size = self.num_attention_heads

        # bind the positional dimensions of q, k, and v to our dimension objects.
        # the size of each dimension is determined by this binding, and when a dimension
        # is used more than once (e.g. batch), its size is checked for consistency
        # across all uses.
        # The group (heads, features) splits apart a single positional dimension
        # into two dimensions. Since heads.size*features.size == q.size(2)
        # and we specified heads.size, features.size is inferred here.
        q = q[batch, query_sequence, [heads, features]]
        k = k[batch, key_sequence, [heads, features]]
        v = v[batch, key_sequence, [heads, features]]
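        # q now carries the named dimensions (batch, query_sequence, heads, features)
        # and k, v carry (batch, key_sequence, heads, features) rather than positional
        # ones; e.g. with hidden_size=64 and 8 heads, features.size is inferred as 8.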

        # when a cached past_key_value is provided, the model can attend not only to the
        # elements of the current sequence but also to the cached key/value states from
        # previous steps.
        if past_key_value is not None:
            extended_key_sequence = dims()
            key_past = past_key_value[0][batch, heads, key_sequence, features]
            value_past = past_key_value[1][batch, heads, key_sequence, features]
            # cat binds its result to a new dimension, extended_key_sequence, because the
            # concatenated tensor is longer than the original key_sequence
            k = cat([key_past, k], key_sequence, extended_key_sequence)
            v = cat([value_past, v], key_sequence, extended_key_sequence)
            # for the rest of the function, we will just use extended_key_sequence in lieu of
            # key_sequence
            key_sequence = extended_key_sequence
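            # e.g. if the cached keys cover 10 positions and the current k covers 16,
            # extended_key_sequence.size is 26 (the sizes here are purely illustrative)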

        # Take the dot product between "query" and "key" to get the raw attention scores.
        # The actual outer-product and summation are explicitly represented here,
        # and like einsum, will be pattern matched to an efficient matrix multiply op.
        attention_scores = (q * k).sum(features) / math.sqrt(features.size)
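        # up to dimension ordering, this is the same computation as
        # torch.einsum("bqhf,bkhf->bhqk", q, k) / math.sqrt(features.size)
        # would be on positionally-shaped tensors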

        # relative positional embeddings give a unique embedding based on the distance between
        # the query and key tokens in the sequence, e.g.
        #  0  1  2  3
        # -1  0  1  2
        # -2 -1  0  1
        # -3 -2 -1  0
        if self.position_embedding_type is not None:
            # the value of a dimension object when used as a tensor is the indices along its dimension
            # so we can directly subtract the two dimensions to get a 2D tensor of (query_sequence x key_sequence)
            # with the distance between them
            distance = query_sequence - key_sequence

            assert key_sequence.size <= self.max_position_embeddings

            # we can then use the distance tensor as an indirect index into the embedding table
            # to look up the features for each distance. this is just a `gather` primitive op.
            # the resulting tensor has all the dimensions of the index
            # (query_sequence x key_sequence), plus the remaining dimension of the embedding
            # weight, which is bound to `features`.
            # this form of indirect indexing is more straightforward than either advanced
            # indexing or torch.gather, both of which depend heavily on the positions of the
            # indexing tensors.

            positional_embedding = self.distance_embedding.weight[
                self.max_position_embeddings - 1 + distance, features
            ]
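            # positional_embedding carries the dimensions
            # (query_sequence, key_sequence, features): one embedding vector per
            # (query position, key position) pair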

            if self.position_embedding_type == "relative_key":
                # these were einsum ops in the original positional code because they are not
                # easy to fit to existing matmul operators, even though they are degenerate matmuls
                relative_position_scores = (q * positional_embedding).sum(features)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = (q * positional_embedding).sum(
                    features
                )
                relative_position_scores_key = (k * positional_embedding).sum(features)
                attention_scores = (
                    attention_scores
                    + relative_position_scores_query
                    + relative_position_scores_key
                )

        # Normalize the attention scores to probabilities.
        attention_probs = softmax(attention_scores, dim=key_sequence)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = torch.nn.functional.dropout(
            attention_probs, p=self.dropout_prob, training=self.training
        )

        # similarly, we can replace the matmul with an explicit product and sum, which makes it
        # clear that we are weighting the values v across all keys by the attention probabilities.
        context_layer = (attention_probs * v).sum(key_sequence)
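        # again equivalent, up to dimension ordering, to
        # torch.einsum("bhqk,bkhf->bqhf", attention_probs, v)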

        # finally, we convert back to a standard positional tensor by listing the desired layout
        # of dimensions. working in reverse to the indexing above, the (heads, features) group
        # flattens those two dimensions back into a single one.
        return context_layer.order(batch, query_sequence, [heads, features])
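

if __name__ == "__main__":
    # A minimal smoke test (an illustrative sketch, not part of the original module;
    # any hidden_size divisible by num_attention_heads would work).
    torch.manual_seed(0)
    attention = BertSelfAttention(
        hidden_size=64,
        num_attention_heads=8,
        attention_probs_dropout_prob=0.1,
    )
    hidden_states = torch.randn(2, 16, 64)  # (batch, sequence, hidden_size)
    out = attention(hidden_states)
    print(out.shape)  # expected: torch.Size([2, 16, 64])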
