import math

from torch import nn
from torch.nn import init


def _initialize_orthogonal(conv, gain=math.sqrt(2)):
    """Orthogonally initialise ``conv.weight`` in place and zero its bias.

    Uses the in-place ``init.orthogonal_`` rather than the deprecated
    ``init.orthogonal`` alias, which emits a deprecation warning on every call.

    Args:
        conv: a layer exposing ``weight`` and ``bias`` attributes (``bias`` may
            be ``None``), e.g. ``nn.Conv2d``.
        gain: scale applied to the orthogonal matrix. Defaults to ``sqrt(2)``,
            the value ``init.calculate_gain("relu")`` returns.
    """
    init.orthogonal_(conv.weight, gain=gain)
    if conv.bias is not None:
        init.zeros_(conv.bias)


class ResidualBlock(nn.Module):
    """Shape-preserving residual unit.

    Computes ``x + bn2(conv2(prelu(bn1(conv1(x)))))``, where both convolutions
    are 3x3 with padding 1 and ``n_filters`` input/output channels. Channel
    count and spatial size therefore match ``x``, so the identity shortcut can
    be added directly.

    Args:
        n_filters: number of input and output channels.
    """

    def __init__(self, n_filters):
        super().__init__()

        def conv3x3():
            # Each conv feeds a BatchNorm, which re-centres its output, so a
            # conv bias would be redundant and is disabled.
            return nn.Conv2d(
                n_filters, n_filters, kernel_size=3, padding=1, bias=False
            )

        self.conv1 = conv3x3()
        self.bn1 = nn.BatchNorm2d(n_filters)
        self.prelu = nn.PReLU(n_filters)
        self.conv2 = conv3x3()
        self.bn2 = nn.BatchNorm2d(n_filters)

        # Orthogonal initialisation, applied after all layers are built.
        for conv in (self.conv1, self.conv2):
            _initialize_orthogonal(conv)

    def forward(self, x):
        h = self.conv1(x)
        h = self.prelu(self.bn1(h))
        h = self.conv2(h)
        return x + self.bn2(h)


class UpscaleBlock(nn.Module):
    """Sub-pixel upsampling stage that doubles height and width.

    A 3x3 convolution expands ``n_filters`` channels to ``4 * n_filters``.
    ``nn.PixelShuffle(2)`` then folds those into ``n_filters`` channels at
    twice the spatial resolution. A per-channel PReLU is applied last.

    Args:
        n_filters: number of input and output channels.
    """

    def __init__(self, n_filters):
        super().__init__()
        expanded_channels = n_filters * 4
        self.upscaling_conv = nn.Conv2d(
            n_filters, expanded_channels, kernel_size=3, padding=1
        )
        self.upscaling_shuffler = nn.PixelShuffle(2)
        self.upscaling = nn.PReLU(n_filters)
        _initialize_orthogonal(self.upscaling_conv)

    def forward(self, x):
        features = self.upscaling_conv(x)
        features = self.upscaling_shuffler(features)
        return self.upscaling(features)


class SRResNet(nn.Module):
    """Residual super-resolution network mapping a 3-channel image upward.

    Layout:
        * head: 9x9 conv (3 -> ``n_filters``) + PReLU;
        * ``n_blocks`` ``ResidualBlock`` modules;
        * 3x3 conv + batch norm whose output is added back onto the head's
          output (long skip connection);
        * one 2x ``UpscaleBlock`` per factor of two in ``rescale_factor``;
        * 9x9 conv projecting ``n_filters`` channels back to 3.

    Args:
        rescale_factor: total spatial upscaling. Must be a power of two >= 1;
            1 means no upscaling stages.
        n_filters: channel width used throughout the network body.
        n_blocks: number of residual blocks (0 is allowed).

    Raises:
        ValueError: if ``rescale_factor`` is not a power of two >= 1.
    """

    def __init__(self, rescale_factor, n_filters, n_blocks):
        super().__init__()
        # Each UpscaleBlock doubles H and W, so only powers of two are
        # reachable. Truncating a non-integral log2 would silently build a
        # network with the wrong output scale, so reject such factors instead.
        # math.log2 is exact for powers of two, unlike math.log(x, 2).
        if rescale_factor < 1 or not math.log2(rescale_factor).is_integer():
            raise ValueError(
                f"rescale_factor must be a power of two >= 1, got {rescale_factor!r}"
            )
        self.rescale_levels = int(math.log2(rescale_factor))
        self.n_filters = n_filters
        self.n_blocks = n_blocks

        self.conv1 = nn.Conv2d(3, n_filters, kernel_size=9, padding=4)
        self.prelu1 = nn.PReLU(n_filters)

        # Submodule names and the single-element nn.Sequential wrapping are
        # part of the state_dict key layout (e.g. "residual_block1.0.conv1.weight");
        # keep them so existing checkpoints still load.
        for residual_block_num in range(1, n_blocks + 1):
            residual_block = ResidualBlock(self.n_filters)
            self.add_module(
                f"residual_block{residual_block_num}",
                nn.Sequential(residual_block),
            )

        self.skip_conv = nn.Conv2d(
            n_filters, n_filters, kernel_size=3, padding=1, bias=False
        )
        self.skip_bn = nn.BatchNorm2d(n_filters)

        for upscale_block_num in range(1, self.rescale_levels + 1):
            upscale_block = UpscaleBlock(self.n_filters)
            self.add_module(
                f"upscale_block{upscale_block_num}", nn.Sequential(upscale_block)
            )

        self.output_conv = nn.Conv2d(n_filters, 3, kernel_size=9, padding=4)

        # Orthogonal initialisation
        _initialize_orthogonal(self.conv1)
        _initialize_orthogonal(self.skip_conv)
        _initialize_orthogonal(self.output_conv)

    def forward(self, x):
        x_init = self.prelu1(self.conv1(x))
        # Start from the head's output so that n_blocks == 0 is a valid
        # (empty) residual trunk instead of a missing-attribute error.
        x = x_init
        for residual_block_num in range(1, self.n_blocks + 1):
            x = getattr(self, f"residual_block{residual_block_num}")(x)
        x = self.skip_bn(self.skip_conv(x)) + x_init
        for upscale_block_num in range(1, self.rescale_levels + 1):
            x = getattr(self, f"upscale_block{upscale_block_num}")(x)
        return self.output_conv(x)
