# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch


# Modules that exercise only the related operator(s)
class Add(torch.nn.Module):
    """Element-wise sum of two tensors via ``torch.add``."""

    def forward(self, x, y):
        return torch.add(x, y)


class AddConstantFloat(torch.nn.Module):
    """Compute ``10.0 + x`` (float scalar on the left-hand side)."""

    def forward(self, x):
        return 10.0 + x


class AddConstantLong(torch.nn.Module):
    """Compute ``10 + x`` (integer scalar on the left-hand side)."""

    def forward(self, x):
        return 10 + x


class Arange(torch.nn.Module):
    """Add a float32 ``torch.arange`` ramp to the input.

    Args:
        x: ``end`` argument for ``torch.arange``, fixed at construction.
    """

    def __init__(self, x):
        super().__init__()
        self.x = x

    def forward(self, y):
        ramp = torch.arange(self.x, dtype=torch.float32)
        return ramp + y


class AvgPoolModule(torch.nn.Module):
    """2x2 average pooling, stride 1, padding 1, excluding padding from the mean."""

    def __init__(self):
        super().__init__()
        pool_kwargs = dict(
            kernel_size=(2, 2),
            stride=(1, 1),
            padding=(1, 1),
            # Padded zeros are not counted in each window's average.
            count_include_pad=False,
        )
        self.avgPool = torch.nn.AvgPool2d(**pool_kwargs)

    def forward(self, x):
        return self.avgPool(x)


class BatchNorm(torch.nn.Module):
    """A single ``BatchNorm2d`` layer, switched to eval mode on construction.

    Args:
        n_features: Number of channels normalized by the layer.
    """

    def __init__(self, n_features):
        super().__init__()
        self.native_batchnorm = torch.nn.BatchNorm2d(n_features)
        # Equivalent to ``self.eval()``: normalize with running statistics.
        self.train(False)

    def forward(self, x):
        return self.native_batchnorm(x)


class Bmm(torch.nn.Module):
    """Matrix product of ``x`` and ``y`` via ``torch.matmul``."""

    def forward(self, x, y):
        return torch.matmul(x, y)


class Cast(torch.nn.Module):
    """Cast the input with ``Tensor.type(torch.IntTensor)``."""

    def forward(self, x):
        target_type = torch.IntTensor
        return x.type(target_type)


class Cat2(torch.nn.Module):
    """Concatenate ``(x, y)`` along dim 2 with ``torch.cat``."""

    def forward(self, x, y):
        pieces = (x, y)
        return torch.cat(pieces, axis=2)


class Cat3(torch.nn.Module):
    """Concatenate ``(y, y, x)`` along dim 2 using the ``torch.concat`` alias."""

    def forward(self, x, y):
        pieces = (y, y, x)
        return torch.concat(pieces, axis=2)


class Cat4(torch.nn.Module):
    """Concatenate ``(y, y, x, x)`` along dim 2 with ``torch.cat``."""

    def forward(self, x, y):
        pieces = (y, y, x, x)
        return torch.cat(pieces, axis=2)


class Ceil(torch.nn.Module):
    """Element-wise ceiling via ``torch.ceil``."""

    def forward(self, x):
        return torch.ceil(x)


class Chunk(torch.nn.Module):
    """Split the input into two chunks along the last dimension."""

    def forward(self, x):
        last_dim = -1
        return torch.chunk(x, chunks=2, dim=last_dim)


class ChunkAdd(torch.nn.Module):
    """Split the last dimension into two chunks and add them together."""

    def forward(self, x):
        # Unpacking requires exactly two chunks, as in the original layout.
        left, right = torch.chunk(x, chunks=2, dim=-1)
        return torch.add(left, right)


class Clamp(torch.nn.Module):
    """Clamp from above at 0 (no lower bound)."""

    def forward(self, x):
        return torch.clamp(x, max=0)


class CompositeDelegateModule(torch.nn.Module):
    """Chain several independently lowered sub-modules inside one module.

    Four sub-modules (two ``Conv2dSequential``, one ``Add``, one ``Relu``) are
    each captured and lowered separately in ``__init__``. ``forward`` wires the
    resulting lowered modules together as
    ``relu(add(conv_a(x), conv_b(y)))``.

    Args:
        compiler_specs: Passed unchanged to ``partitioner_type``.
        partitioner_type: Callable building a partitioner from
            ``compiler_specs``; a fresh partitioner is built per sub-module.
        capture_method: Called as ``capture_method(module, sample_input)``;
            its result must expose an assignable ``exported_program``.
        lowered_method: Called as ``lowered_method(exported_program,
            partitioner)``; the result replaces ``exported_program``.
        quantize_method: Optional; when truthy it is called as
            ``quantize_method(module, sample_input)`` before capture.
    """

    def __init__(
        self,
        compiler_specs,
        partitioner_type,
        capture_method,
        lowered_method,
        quantize_method=None,
    ) -> None:
        super().__init__()
        # NOTE(review): a plain list assigned to ``self.modules`` is stored as
        # an instance attribute that shadows the inherited
        # ``torch.nn.Module.modules()`` method on this object. List entries are
        # also not registered submodules, so ``.to()``, ``.eval()`` and
        # ``state_dict()`` do not reach them. Kept as-is because callers may
        # read ``.modules`` — confirm before renaming.
        self.modules = [
            Conv2dSequential(),
            Conv2dSequential(),
            Add(),
            Relu(),
        ]
        # One sample-input tuple per entry of ``self.modules``, in the same
        # order. The ``Add``/``Relu`` inputs have 2 channels, matching the
        # output channel count of ``Conv2dSequential``.
        self.sample_inputs = [
            (torch.randn([1, 1, 3, 3]),),
            (torch.randn([1, 1, 3, 3]),),
            (torch.randn([1, 2, 3, 3]), torch.randn([1, 2, 3, 3])),
            (torch.randn([1, 2, 3, 3]),),
        ]
        self.lowered_modules = []
        for module, sample_input in zip(self.modules, self.sample_inputs):
            partitioner = partitioner_type(compiler_specs)
            if quantize_method:
                # Rebinds only the loop variable: ``self.modules`` keeps the
                # unquantized originals used by ``get_reference_module``.
                module = quantize_method(module, sample_input)
            edge_prog = capture_method(module, sample_input)
            edge_prog.exported_program = lowered_method(
                edge_prog.exported_program, partitioner
            )
            # Assumes lowering produced a submodule named "lowered_module_0"
            # — TODO confirm; ``.get`` appends None if it is missing.
            self.lowered_modules.append(
                edge_prog.exported_program.graph_module._modules.get("lowered_module_0")
            )

    def forward(self, x, y):
        # Each lowered module's result is indexed with ``[0]``, i.e. treated
        # as a sequence whose first element is the output tensor.
        x1 = self.lowered_modules[0](x)
        x2 = self.lowered_modules[1](y)
        x3 = self.lowered_modules[2](x1[0], x2[0])
        x4 = self.lowered_modules[3](x3[0])
        return x4[0]

    def get_random_input(self):
        """Return an ``(x, y)`` pair of random ``[1, 1, 3, 3]`` tensors."""
        return (torch.randn([1, 1, 3, 3]), torch.randn([1, 1, 3, 3]))

    def get_reference_module(self):
        """Return an eager module running the same chain on ``self.modules``.

        The reference uses the original (not lowered, not quantized)
        sub-modules, which return tensors directly rather than sequences.
        """

        class CompositeReferenceModule(torch.nn.Module):
            def __init__(self, modules):
                super().__init__()
                # NOTE(review): also shadows ``torch.nn.Module.modules()``.
                self.modules = modules

            def forward(self, x, y):
                x1 = self.modules[0](x)
                x2 = self.modules[1](y)
                x3 = self.modules[2](x1, x2)
                x4 = self.modules[3](x3)
                return x4

        return CompositeReferenceModule(self.modules)


class ContextBinaryExample(torch.nn.Module):
    """Apply ReLU independently to two inputs and return both results."""

    def forward(self, x, y):
        relu = torch.nn.functional.relu
        return relu(x), relu(y)

    def example_inputs(self):
        """Return keyword inputs for ``forward``: ``x`` then ``y`` (random)."""
        return dict(
            x=torch.randn((1, 3, 3, 3)),
            y=torch.randn((2, 1, 5, 5)),
        )


class Conv1dSequential(torch.nn.Module):
    """Two stacked ``Conv1d`` layers: 1 -> 3 -> 2 channels, kernel 3, padding 1.

    Args:
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        # Built in order first -> second so parameter init consumes the RNG
        # in the same sequence.
        self.first = torch.nn.Conv1d(1, 3, kernel_size=3, padding=1, bias=bias)
        self.second = torch.nn.Conv1d(3, 2, kernel_size=3, padding=1, bias=bias)

    def forward(self, x):
        hidden = self.first(x)
        return self.second(hidden)


# small models
class Conv1dReluLogSoftmax(torch.nn.Module):
    """Pointwise ``Conv1d`` (2 -> 2, padding 1), ReLU, then LogSoftmax over dim 1."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(
            in_channels=2,
            out_channels=2,
            kernel_size=1,
            stride=1,
            padding=1,
        )
        self.logsoftmax = torch.nn.LogSoftmax(dim=1)

    def forward(self, x):
        activated = torch.nn.functional.relu(self.conv(x))
        return self.logsoftmax(activated)


class Conv2dAvgPool2d(torch.nn.Module):
    """7x7 stride-2 ``Conv2d`` (3 -> 16) followed by 3x3 stride-2 average pooling."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3,
            out_channels=16,
            kernel_size=7,
            stride=2,
            padding=3,
            dilation=1,
            bias=True,
        )
        self.pool = torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        features = self.conv(x)
        return self.pool(features)


class Conv2dBnHardtanhMean(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> Hardtanh(0, 6) -> channel mean (kept as a dim).

    The conv is 1 -> 1 channel, 3x3, stride 2, padding 1, and its weight is
    replaced by a standard-normal draw. The module is put in eval mode.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=1,
            out_channels=1,
            kernel_size=(3, 3),
            stride=[2, 2],
            padding=[1, 1],
            groups=1,
            dilation=[1, 1],
            bias=True,
        )
        # Overwrite the default init with a standard-normal weight.
        self.conv.weight = torch.nn.Parameter(torch.randn(self.conv.weight.size()))
        self.native_batchnorm = torch.nn.BatchNorm2d(1)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
        self.eval()

    def forward(self, x):
        normalized = self.native_batchnorm(self.conv(x))
        clipped = self.hardtanh(normalized)
        # Reduce over the channel dimension only.
        return torch.mean(clipped, 1, keepdim=True)


class Conv2dCat(torch.nn.Module):
    """Convolve two inputs separately (3x3, 3 -> 3) and concatenate on channels."""

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)
        self.conv2 = torch.nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3)

    def forward(self, x, y):
        branches = [self.conv1(x), self.conv2(y)]
        return torch.cat(branches, dim=1)


class Conv2dMaxPool2d(torch.nn.Module):
    """1x1 ``Conv2d`` (2 -> 2, padding 1) followed by a 1x1 stride-1 max pool."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=2,
            out_channels=2,
            kernel_size=(1, 1),
            padding=1,
            bias=True,
        )
        self.pool = torch.nn.MaxPool2d(kernel_size=1, stride=1)

    def forward(self, x):
        features = self.conv(x)
        return self.pool(features)


class Conv2dSequential(torch.nn.Module):
    """Two stacked 3x3 ``Conv2d`` layers: 1 -> 3 -> 2 channels, padding 1.

    Args:
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        common = dict(kernel_size=(3, 3), padding=1, bias=bias)
        self.first = torch.nn.Conv2d(in_channels=1, out_channels=3, **common)
        self.second = torch.nn.Conv2d(in_channels=3, out_channels=2, **common)

    def forward(self, x):
        hidden = self.first(x)
        return self.second(hidden)


class Conv2dSingle(torch.nn.Module):
    """A single 3x3 ``Conv2d`` (1 -> 3 channels, padding 1).

    Args:
        bias: Whether the convolution carries a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, (3, 3), padding=1, bias=bias)

    def forward(self, x):
        return self.conv(x)


class ConvTranspose2dSingle(torch.nn.Module):
    """A single 3x3 stride-2 ``ConvTranspose2d`` (1 -> 3 channels, padding 1).

    Args:
        bias: Whether the transposed convolution carries a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        settings = dict(kernel_size=3, stride=2, padding=1, bias=bias)
        self.conv_transpose = torch.nn.ConvTranspose2d(1, 3, **settings)

    def forward(self, x):
        return self.conv_transpose(x)


class Conv2dDownUpSample(torch.nn.Module):
    """Downsample with a stride-2 ``Conv2d`` then upsample with ``ConvTranspose2d``.

    Both layers are 16 -> 16 channels, kernel 3, stride 2, padding 1.

    Args:
        bias: Whether both layers carry a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        shared = dict(
            in_channels=16,
            out_channels=16,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=bias,
        )
        self.conv = torch.nn.Conv2d(**shared)
        self.conv_transpose = torch.nn.ConvTranspose2d(**shared)

    def forward(self, x):
        downsampled = self.conv(x)
        return self.conv_transpose(downsampled)


class Conv2dSumReduceDim(torch.nn.Module):
    """3x3 ``Conv2d`` (1 -> 3, padding 1), then sum over dims 2 and 3 (dropped)."""

    def __init__(self):
        super().__init__()
        self.first = torch.nn.Conv2d(1, 3, kernel_size=(3, 3), padding=1, bias=True)

    def forward(self, x):
        features = self.first(x)
        return torch.sum(features, dim=(2, 3), keepdim=False)


class Conv2dTopK(torch.nn.Module):
    """3x3 ``Conv2d`` (3 -> 16), then the top-5 values along the channel dim."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3)

    def forward(self, x):
        features = self.conv(x)
        # Only the values are returned; the indices are discarded.
        values, _ = torch.topk(features, 5, dim=1)
        return values


class Div(torch.nn.Module):
    """Element-wise ``x / y`` via ``torch.divide``."""

    def forward(self, x, y):
        return torch.divide(x, y)


class DivConstantFloat(torch.nn.Module):
    """Divide the input by the float scalar ``10.0``."""

    def forward(self, x):
        return x / 10.0


class DivConstantLong(torch.nn.Module):
    """Divide the input by the integer scalar ``10`` (true division)."""

    def forward(self, x):
        return x / 10


class EinsumBilinear(torch.nn.Module):
    """Bilinear form ``out[b, a] = sum_{n,m} bn[b,n] * anm[a,n,m] * bm[b,m]``."""

    def forward(self, bn, anm, bm):
        equation = "bn,anm,bm->ba"
        return torch.einsum(equation, bn, anm, bm)


class EinsumOuterProduct(torch.nn.Module):
    """Outer product of two vectors expressed with ``torch.einsum``."""

    def forward(self, i, j):
        equation = "i,j->ij"
        return torch.einsum(equation, i, j)


class EinsumOuterProductRelu(torch.nn.Module):
    """ReLU applied to the ``torch.einsum`` outer product of two vectors."""

    def forward(self, i, j):
        outer = torch.einsum("i,j->ij", i, j)
        return torch.relu(outer)


class Embedding(torch.nn.Module):
    """Look up rows of a 10-entry, 3-dimensional ``torch.nn.Embedding`` table."""

    def __init__(self):
        super().__init__()
        self.embedding = torch.nn.Embedding(num_embeddings=10, embedding_dim=3)

    def forward(self, x):
        return self.embedding(x)


class ExpandCopy(torch.nn.Module):
    """Broadcast the input to shape ``(3, 4)`` with ``Tensor.expand``."""

    def forward(self, x):
        rows, cols = 3, 4
        return x.expand(rows, cols)


class Gelu(torch.nn.Module):
    """Apply ``torch.nn.GELU`` with default settings."""

    def __init__(self):
        super().__init__()
        self.gelu = torch.nn.GELU()

    def forward(self, x):
        return self.gelu(x)


class GroupNorm(torch.nn.Module):
    """3x3 ``Conv2d`` (32 -> 256) followed by ``GroupNorm`` with 32 groups.

    Returns both the conv output and its normalized version.

    Args:
        bias: Whether the convolution carries a bias term.
    """

    def __init__(self, bias=True):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=32,
            out_channels=256,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
        )
        self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=256)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.norm(features)
        return features, normalized


class HardSigmoid(torch.nn.Module):
    """Apply ``torch.nn.Hardsigmoid``."""

    def __init__(self):
        super().__init__()
        self.hardsigmoid = torch.nn.Hardsigmoid()

    def forward(self, x):
        return self.hardsigmoid(x)


class HardSwish(torch.nn.Module):
    """Apply ``torch.nn.Hardswish``."""

    def __init__(self):
        super().__init__()
        self.hardswish = torch.nn.Hardswish()

    def forward(self, x):
        return self.hardswish(x)


class HardTanh(torch.nn.Module):
    """Clip values to ``[0, 6]`` with ``torch.nn.Hardtanh``."""

    def __init__(self):
        super().__init__()
        lower, upper = 0, 6
        self.hardtanh = torch.nn.Hardtanh(min_val=lower, max_val=upper)

    def forward(self, x):
        return self.hardtanh(x)


class Index(torch.nn.Module):
    """Gather from ``x`` with two fixed int32 index tensors and sum the results."""

    def __init__(self):
        super().__init__()
        # Plain tensor attributes (not registered buffers).
        self.idx0 = torch.tensor([[0, 1], [2, 3], [4, 5]], dtype=torch.int32)
        self.idx1 = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.int32)

    def forward(self, x):
        first = x[self.idx0]
        second = x[self.idx1]
        return first + second


class IndexPut(torch.nn.Module):
    """In-place ``index_put_`` into a zero-initialized ``(1, 1024, 12, 64)`` buffer."""

    def __init__(self):
        super().__init__()
        cache = torch.zeros((1, 1024, 12, 64), dtype=torch.float32)
        self.register_buffer("k_cache", cache)

    def forward(self, input_pos, k_val):
        # ``None`` leaves dim 0 unindexed, so ``input_pos`` selects along
        # dim 1. The buffer is mutated and the op's result is returned.
        return torch.ops.aten.index_put_(self.k_cache, [None, input_pos], k_val)


class LayerNorm(torch.nn.Module):
    """``LayerNorm`` over a last dim of 768 (eps=1e-6), then ``Linear(768, 196)``."""

    def __init__(self):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
        self.linear = torch.nn.Linear(in_features=768, out_features=196)

    def forward(self, x):
        normalized = self.layer_norm(x)
        return self.linear(normalized)


class LeakyReLUDefault(torch.nn.Module):
    """``torch.nn.LeakyReLU`` with its default negative slope."""

    def __init__(self):
        super().__init__()
        self.leaky_relu = torch.nn.LeakyReLU()

    def forward(self, x):
        return self.leaky_relu(x)


class LeakyReLUCustom(torch.nn.Module):
    """``torch.nn.LeakyReLU`` with a caller-supplied negative slope.

    Args:
        coeff: Negative slope passed to ``LeakyReLU``.
    """

    def __init__(self, coeff):
        super().__init__()
        self.leaky_relu = torch.nn.LeakyReLU(coeff)

    def forward(self, x):
        return self.leaky_relu(x)


class Linear(torch.nn.Module):
    """A ``Linear(4, 5)`` layer placed in eval mode.

    Args:
        use_bias: Whether the layer carries a bias term.
    """

    def __init__(self, use_bias: bool = True):
        super().__init__()
        self.linear = torch.nn.Linear(in_features=4, out_features=5, bias=use_bias)
        self.linear.eval()

    def forward(self, x):
        return self.linear(x)


class LogSoftmax(torch.nn.Module):
    """Log-softmax over the last dimension."""

    def forward(self, x):
        return torch.nn.functional.log_softmax(x, dim=-1)


class MaxPool2d(torch.nn.Module):
    """3x3 max pooling with stride 1, padding 1, dilation 1 and ``ceil_mode``."""

    def __init__(self):
        super().__init__()
        self.max_pool2d = torch.nn.MaxPool2d(3, 1, 1, 1, ceil_mode=True)

    def forward(self, x):
        return self.max_pool2d(x)


class MeanWKeppDim(torch.nn.Module):
    """Mean over the last two dimensions, keeping them as size-1 dims."""

    def forward(self, x):
        reduce_dims = (-1, -2)
        return torch.mean(x, reduce_dims, keepdim=True)


class MeanWOKeppDim(torch.nn.Module):
    """Mean over the last two dimensions, dropping them."""

    def forward(self, x):
        reduce_dims = (-1, -2)
        return torch.mean(x, reduce_dims)


class Mul(torch.nn.Module):
    """Element-wise product via ``torch.mul``."""

    def forward(self, x, y):
        return torch.mul(x, y)


class MulConstantFloat(torch.nn.Module):
    """Compute ``10.0 * x`` (float scalar on the left-hand side)."""

    def forward(self, x):
        return 10.0 * x


class MulConstantLong(torch.nn.Module):
    """Compute ``10 * x`` (integer scalar on the left-hand side)."""

    def forward(self, x):
        return 10 * x


class MulScalar(torch.nn.Module):
    """Multiply by the scalar 3.14 through ``aten.mul.Scalar`` explicitly."""

    def __init__(self):
        super().__init__()
        self._scalar = 3.14

    def forward(self, x):
        scalar_mul = torch.ops.aten.mul.Scalar
        return scalar_mul(x, self._scalar)


class MultiheadAttention(torch.nn.Module):
    """Self-attention with ``torch.nn.MultiheadAttention`` (96 dims, 12 heads).

    Inputs are batch-first; attention weights are not requested.
    """

    def __init__(self):
        super().__init__()
        self.multi_head_attention = torch.nn.MultiheadAttention(
            embed_dim=96, num_heads=12, dropout=0.0, batch_first=True
        )

    def forward(self, x):
        # Query, key and value are all ``x``.
        output, _ = self.multi_head_attention(x, x, x, need_weights=False)
        return output


class Pad(torch.nn.Module):
    """Drop the first entry of dim 1, then zero-pad one row after dim -2."""

    def forward(self, x):
        trimmed = x[:, 1:]
        # Pad spec pairs go from the last dim backwards: (last), (-2), (-3).
        return torch.nn.functional.pad(
            trimmed, [0, 0, 0, 1, 0, 0], value=0.0, mode="constant"
        )


class PixelShuffle(torch.nn.Module):
    """Apply ``torch.nn.PixelShuffle`` with the given upscale factor.

    Args:
        scale: Upscale factor.
    """

    def __init__(self, scale):
        super().__init__()
        self.pixel_shuffle = torch.nn.PixelShuffle(scale)

    def forward(self, x):
        return self.pixel_shuffle(x)


class PixelUnshuffle(torch.nn.Module):
    """Apply ``torch.nn.PixelUnshuffle`` with the given downscale factor.

    Args:
        scale: Downscale factor.
    """

    def __init__(self, scale):
        super().__init__()
        self.pixel_unshuffle = torch.nn.PixelUnshuffle(scale)

    def forward(self, x):
        return self.pixel_unshuffle(x)


class PixelUnshuffleMathEquivalent(torch.nn.Module):
    """Pixel-unshuffle written with ``view``/``permute``/``reshape``.

    Args:
        scale: Downscale factor applied to both spatial dimensions.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        s = self.scale
        batch, channels, height, width = x.size()
        out_h, out_w = height // s, width // s
        # Split each spatial dim into (coarse, fine) and move both fine
        # factors next to the channel dim before folding them into it.
        blocks = x.view(batch, channels, out_h, s, out_w, s)
        reordered = blocks.permute(0, 1, 3, 5, 2, 4)
        return reordered.reshape(batch, channels * (s**2), out_h, out_w)


class PowTensorScalar(torch.nn.Module):
    """Square the input via ``torch.pow`` with a scalar exponent."""

    def forward(self, x):
        exponent = 2
        return torch.pow(x, exponent)


class PReLUDefault(torch.nn.Module):
    """``torch.nn.PReLU`` with a single learnable slope at its default init."""

    def __init__(self):
        super().__init__()
        self.prelu = torch.nn.PReLU()

    def forward(self, x):
        return self.prelu(x)


class PReLUPerChannel(torch.nn.Module):
    """``torch.nn.PReLU`` with one learnable slope per channel.

    Args:
        channels: Number of slopes (``num_parameters``).
    """

    def __init__(self, channels):
        super().__init__()
        self.prelu = torch.nn.PReLU(channels)

    def forward(self, x):
        return self.prelu(x)


class Relu(torch.nn.Module):
    """Apply ``torch.nn.ReLU``."""

    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x)


class Reshape(torch.nn.Module):
    """Reshape a 12-element input to ``(1, 12)``."""

    def forward(self, x):
        target = (1, 12)
        return x.reshape(*target)


class ResidualBlockModule(torch.nn.Module):
    """Residual block that reuses one conv and one batch-norm twice.

    ``skip = bn(conv(x))``; the output is ``hardtanh(bn(conv(skip))) + skip``.
    The conv is 32 -> 32 channels, 3x3, stride 1, padding 1. The module is
    put in eval mode.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=32,
            out_channels=32,
            kernel_size=(3, 3),
            stride=[1, 1],
            padding=[1, 1],
            groups=1,
            dilation=[1, 1],
            bias=True,
        )
        self.native_batchnorm = torch.nn.BatchNorm2d(32)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6.0)
        self.eval()

    def forward(self, x):
        skip = self.native_batchnorm(self.conv(x))
        # The same conv / batch-norm weights are applied a second time.
        body = self.hardtanh(self.native_batchnorm(self.conv(skip)))
        return torch.add(body, skip)


class ResizeBilinear2D(torch.nn.Module):
    """Upsample the last two dimensions by 2x with bilinear interpolation.

    Uses ``align_corners=False``.
    """

    def forward(self, x):
        output_shape = [dim * 2 for dim in x.shape[-2:]]
        # Pass the computed size directly. No scratch tensor is allocated
        # just to read its shape back, so the global RNG is left untouched.
        return torch.nn.functional.interpolate(
            x,
            size=output_shape,
            mode="bilinear",
            align_corners=False,
        )


class ResizeNearest2D(torch.nn.Module):
    """Upsample the last two dimensions by 2x with nearest-neighbor interpolation."""

    def forward(self, x):
        output_shape = [dim * 2 for dim in x.shape[-2:]]
        # Pass the computed size directly. No scratch tensor is allocated
        # just to read its shape back, so the global RNG is left untouched.
        return torch.nn.functional.interpolate(
            x,
            size=output_shape,
            mode="nearest",
        )


class RmsNorm(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.eps = 1e-5
        self.rms = torch.nn.RMSNorm([4], 1e-5)

    def forward(self, x):
        return self.rms(x)


class Rsqrt(torch.nn.Module):
    """Element-wise reciprocal square root via ``torch.rsqrt``."""

    def forward(self, x):
        return torch.rsqrt(x)


class ScaledDotProductAttention(torch.nn.Module):
    """Call ``torch.nn.functional.scaled_dot_product_attention`` with a mask."""

    def forward(self, query_layer, key_layer, value_layer, attn_mask):
        sdpa = torch.nn.functional.scaled_dot_product_attention
        return sdpa(query_layer, key_layer, value_layer, attn_mask)


class SelectCopy(torch.nn.Module):
    """3x3 ``Conv2d`` (3 -> 2, padding 1), then select ``[0, 1, 1:2]``."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 2, kernel_size=(3, 3), padding=1, bias=True)

    def forward(self, x):
        feature_map = self.conv(x)
        # Batch 0, channel 1, a single row (kept as a length-1 slice).
        return feature_map[0, 1, 1:2]


class Sigmoid(torch.nn.Module):
    """Element-wise logistic sigmoid via ``torch.sigmoid``."""

    def forward(self, x):
        return torch.sigmoid(x)


class SimpleModel(torch.nn.Module):
    """Two conv branches merged, pooled, reshaped to ``(8, -1)`` and projected.

    Each branch is conv -> batch-norm -> relu -> conv -> relu on 32 channels.
    Both branches share ``self.batch_norm``. After the sum, the tensor is
    permuted, averaged over dims 1 and 2, reshaped to ``(8, -1)``, passed
    through ``Linear(4, 10)``, and clipped to ``[0, 6]``. The module is put
    in eval mode.
    """

    def __init__(self):
        super().__init__()
        channels = 32

        def conv(bias):
            return torch.nn.Conv2d(channels, channels, 3, padding=1, bias=bias)

        # Creation order is kept (conv1..conv4, then linear) so parameter
        # init consumes the RNG in the same sequence.
        self.conv1 = conv(True)
        self.conv2 = conv(True)
        self.conv3 = conv(False)
        self.conv4 = conv(False)
        self.hardtanh = torch.nn.Hardtanh(min_val=0, max_val=6)
        self.relu = torch.nn.ReLU()
        self.batch_norm = torch.nn.BatchNorm2d(channels)
        self.add = torch.add
        self.mean = torch.mean
        self.reshape = torch.reshape
        self.linear = torch.nn.Linear(4, 10)
        self.permute = torch.permute
        self.eval()

    def forward(self, x, y):
        left = self.relu(self.conv2(self.relu(self.batch_norm(self.conv1(x)))))
        right = self.relu(self.conv4(self.relu(self.batch_norm(self.conv3(y)))))
        merged = self.permute(self.add(left, right), (0, 3, 2, 1))
        pooled = self.mean(merged, [1, 2], True)
        projected = self.linear(self.reshape(pooled, (8, -1)))
        return self.hardtanh(projected)


class SliceCopy(torch.nn.Module):
    """Add the first ``y.size(1)`` columns of ``x`` and of ``position_ids``."""

    def __init__(self):
        super().__init__()
        # Plain random tensor attribute (not a registered buffer).
        self.position_ids = torch.randn([1, 512])

    def forward(self, x, y):
        # Only ``y``'s size along dim 1 is used.
        length = y.size()[1]
        return x[:, :length] + self.position_ids[:, :length]


class SliceCopyWithStep(torch.nn.Module):
    """Like a prefix slice-add, but taking every ``step``-th column (step 2)."""

    def __init__(self):
        super().__init__()
        # Plain random tensor attribute (not a registered buffer).
        self.position_ids = torch.randn([1, 512])
        self.step = 2

    def forward(self, x, y):
        length = y.size()[1]
        window = slice(None, length, self.step)
        return x[:, window] + self.position_ids[:, window]


class Softmax(torch.nn.Module):
    """Softmax over the last dimension."""

    def forward(self, x):
        return torch.nn.functional.softmax(x, dim=-1)


class Sqrt(torch.nn.Module):
    """Element-wise square root via ``torch.sqrt``."""

    def forward(self, x):
        return torch.sqrt(x)


class SqrtConstant(torch.nn.Module):
    """Divide by ``sqrt(64.0)``, computed from a constant tensor at call time."""

    def forward(self, x):
        denominator = torch.sqrt(torch.tensor([64.0]))
        return x / denominator


class Squeeze(torch.nn.Module):
    """Remove every size-1 dimension."""

    def forward(self, x):
        return x.squeeze()


class Stack(torch.nn.Module):
    """Stack ``x`` and ``y`` along a new leading dimension."""

    def forward(self, x, y):
        members = (x, y)
        return torch.stack(members)


class Sub(torch.nn.Module):
    """Element-wise ``x - y`` via ``torch.sub``."""

    def forward(self, x, y):
        return torch.sub(x, y)


class SubConstantFloat(torch.nn.Module):
    """Compute ``10.0 - x`` (float scalar on the left-hand side)."""

    def forward(self, x):
        return 10.0 - x


class SubConstantLong(torch.nn.Module):
    """Compute ``10 - x`` (integer scalar on the left-hand side)."""

    def forward(self, x):
        return 10 - x


class SumIntList(torch.nn.Module):
    """Sum over dims 2 and 3, keeping them as size-1 dims."""

    def forward(self, x):
        spatial_dims = (2, 3)
        return torch.sum(x, dim=spatial_dims, keepdim=True)


class Tanh(torch.nn.Module):
    """Element-wise hyperbolic tangent via ``torch.tanh``."""

    def forward(self, x):
        return torch.tanh(x)


class TopKandIndex(torch.nn.Module):
    """Top-3 values along the last dim, plus rows of a fixed table at those indices."""

    def __init__(self):
        super().__init__()
        # Plain random (10, 3) tensor attribute, indexed by top-k positions.
        self.idx_source = torch.rand(10, 3)

    def forward(self, x):
        values, indices = torch.topk(x, 3)
        gathered = self.idx_source[indices]
        return values + gathered


class Unbind(torch.nn.Module):
    """Split the input into a tuple of slices along dim 0."""

    def forward(self, x):
        return torch.unbind(x)


class Unsqueeze(torch.nn.Module):
    """Insert a new size-1 leading dimension."""

    def forward(self, x):
        return x.unsqueeze(0)


class View(torch.nn.Module):
    """View the last dimension of ``x`` as ``(2, 256)``; ``y`` is ignored."""

    def __init__(self):
        super().__init__()
        self.first_size = 2
        self.second_size = 256

    def forward(self, x, y):
        leading = tuple(x.size()[:-1])
        return x.view(leading + (self.first_size, self.second_size))


class ViewPermuteMatMul(torch.nn.Module):
    """Split ``x``'s last dim into ``(2, 256)``, swap dims 1/2, then matmul with ``y^T``."""

    def __init__(self):
        super().__init__()
        self.first_size = 2
        self.second_size = 256

    def forward(self, x, y):
        leading = tuple(x.size()[:-1])
        split = x.view(leading + (self.first_size, self.second_size))
        reordered = split.permute(0, 2, 1, 3)
        return torch.matmul(reordered, y.transpose(-1, -2))
