# Owner(s): ["module: nn"]
import itertools
import random
import unittest
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
    expectedFailureXLA,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_nn import freeze_rng_state, NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    set_default_dtype,
    TEST_PRIVATEUSE1,
)


class TestDropoutNN(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    def _test_alpha_dropout(self, cls, input):
        mean = input.mean()
        std = input.std()

        for p in [0.2, 0.5, 0.8]:
            module = cls(p)
            input_var = input.detach().clone().requires_grad_()
            output = module(input_var)
            # output mean should be close to input mean
            self.assertLess(abs(output.data.mean() - mean), 0.1)
            # output std should be close to input std
            self.assertLess(abs(output.data.std() - std), 0.1)
            output.backward(input)

    def test_AlphaDropout(self):
        # generate random tensor with zero mean and unit std
        input = torch.randn(5000)
        self._test_alpha_dropout(nn.AlphaDropout, input)

    def test_FeatureAlphaDropout(self):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.randn(num_features, b, d, w, h)
        self._test_alpha_dropout(nn.FeatureAlphaDropout, input)

        # no batch dims
        input = torch.randn(50, 20, 64, 64)
        self._test_alpha_dropout(nn.FeatureAlphaDropout, input)

    @unittest.skipIf(
        not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 unavailable"
    )
    def test_native_dropout_corner_case(self):
        if TEST_CUDA:
            device = "cuda"
        elif TEST_PRIVATEUSE1:
            device = torch._C._get_privateuse1_backend_name()
        for train in [True, False]:
            for p in [0.0, 1.0]:
                for current_device in [device, "cpu"]:
                    x = torch.randn(5).to(device=current_device).requires_grad_()
                    x_ref = x.detach().requires_grad_()
                    o = torch.native_dropout(x, p, train)[0]
                    o_ref = torch.dropout(x_ref, p, train)
                    o.sum().backward()
                    o_ref.sum().backward()
                    assert o.equal(o_ref)
                    assert x.grad.equal(x_ref.grad)

    def test_invalid_dropout_p(self):
        v = torch.ones(1)
        self.assertRaises(ValueError, lambda: nn.Dropout(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout1d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout1d(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout2d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout2d(1.1))
        self.assertRaises(ValueError, lambda: nn.Dropout3d(-0.1))
        self.assertRaises(ValueError, lambda: nn.Dropout3d(1.1))
        self.assertRaises(ValueError, lambda: F.dropout(v, -0.1))
        self.assertRaises(ValueError, lambda: F.dropout(v, 1.1))


class TestDropoutNNDeviceType(NNTestCase):
    def _test_dropout(self, cls, device, input, memory_format=torch.contiguous_format):
        p = 0.2
        input = input.to(device).fill_(1 - p)

        module = cls(p)
        input_var = input.clone(memory_format=memory_format).requires_grad_()
        output = module(input_var)
        self.assertTrue(output.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        module = cls(p, True)
        input_var = input.clone(memory_format=memory_format).requires_grad_()
        output = module(input_var + 0)
        self.assertTrue(output.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(output.data.mean() - (1 - p)), 0.05)
        output.backward(input)
        self.assertTrue(input_var.grad.is_contiguous(memory_format=memory_format))
        self.assertLess(abs(input_var.grad.data.mean() - (1 - p)), 0.05)

        # check eval mode doesn't change anything
        for inplace in [True, False]:
            module = cls(p, inplace).eval()
            self.assertEqual(input, module(input))

        # Check that these don't raise errors
        module.__repr__()
        str(module)

    def _test_dropout_discontiguous(
        self, cls, device, memory_format=torch.contiguous_format
    ):
        # In this test, we verify that dropout preserves the layout and data for different memory formats.
        # We check whether, we get same values for the output of dropout, when the probability
        # of dropout is 0 or very close to 0.
        # Reference: https://github.com/pytorch/pytorch/issues/47176
        close_to_zero_p = 1e-10  # Should be almost zero but not zero, as for p=0 different path is taken
        for p in [0, close_to_zero_p]:
            inp = torch.ones(2, 3, 3, 3, device=device)
            inp_discontiguous = torch.empty(
                2, 3, 3, 6, device=device, memory_format=memory_format
            )[..., ::2]
            inp_discontiguous.copy_(inp)
            mod = cls(p=p)
            out = mod(inp_discontiguous)
            if p != 0:  # Zero will keep strides as is based on input.
                # When prob == 0, input stride (54, 18, 6, 2) -> output stride (54, 18, 6, 2)
                # When prob != 0, input stride (54, 18, 6, 2) -> output stride (27, 9, 3, 1)
                self.assertTrue(out.is_contiguous(memory_format=memory_format))
            self.assertEqual(inp_discontiguous, out)

    def _test_dropout_stride_mean_preserve(self, cls, device):
        def invert_perm(p):
            d = {x: i for i, x in enumerate(p)}
            return (d[0], d[1], d[2], d[3])

        inp = torch.ones(2, 3, 4, 5, device=device)
        shifts = [(0, 0), (1, 0), (0, 1), (1, 1)]
        for perm in itertools.permutations((0, 1, 2, 3), r=4):
            for shift in shifts:
                for p in [1e-10, 0.3, 0.5, 0.7]:
                    mod = cls(p=p)
                    permuted_inp = (
                        inp.permute(perm).contiguous().permute(invert_perm(perm))
                    )
                    permuted_inp = permuted_inp[shift[0] :, shift[1] :, :, :]
                    out = mod(permuted_inp)

                    self.assertTrue(out.permute(perm).is_contiguous())
                    self.assertEqual(inp.mean(), out.mean(), rtol=0.5, atol=0.5)
                    if p == 1e-10:
                        self.assertEqual(permuted_inp, out)
                    else:
                        self.assertNotEqual(permuted_inp, out)

    def test_Dropout(self, device):
        input = torch.empty(1000)
        self._test_dropout(nn.Dropout, device, input)

        self._test_dropout_discontiguous(nn.Dropout, device)
        self._test_dropout_discontiguous(
            nn.Dropout, device, memory_format=torch.channels_last
        )

        self._test_dropout_stride_mean_preserve(nn.Dropout, device)

        if self.device_type == "cuda" or self.device_type == "cpu":
            input = input.bfloat16()
            self._test_dropout(nn.Dropout, device, input)

    def _test_dropoutNd_no_batch(self, dropout, input):
        input_clone = input.clone()
        with freeze_rng_state():
            res_no_batch = dropout(input)

        with freeze_rng_state():
            res_batched = dropout(input_clone.unsqueeze(0)).squeeze(0)

        self.assertEqual(res_no_batch, res_batched)

    def _test_dropoutNd_channel_zero(self, dropout, input):
        # Verify the number of zeros in a channel is 0 or the number of elements in the channel
        # for a fully positive input tensor
        shape = input.shape
        B = shape[0]
        C = shape[1]
        channel_numel = torch.tensor(shape[2:]).prod()
        result = dropout(input)

        for b, c in product(range(B), range(C)):
            self.assertTrue(result[b, c].count_nonzero() in (0, channel_numel))

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout1d(self, device):
        with set_default_dtype(torch.double):
            N, C, L = (
                random.randint(10, 15),
                random.randint(10, 15),
                random.randint(10, 15),
            )
            input = torch.empty(N, C, L)
            self._test_dropout(nn.Dropout1d, device, input)

            with self.assertRaisesRegex(
                RuntimeError, "Expected 2D or 3D input, but received a 4D input"
            ):
                nn.Dropout1d(p=0.5)(torch.rand(1, 2, 2, 2, device=device))

            with self.assertRaisesRegex(
                RuntimeError, "Expected 2D or 3D input, but received a 1D input"
            ):
                nn.Dropout1d(p=0.5)(torch.rand(2, device=device))

            # no batch dims
            input = torch.rand(50, 2, device=device)
            self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5), input)
            self._test_dropoutNd_no_batch(nn.Dropout1d(p=0.5, inplace=True), input)

            # check that complete channels are dropped
            input = torch.ones(10, 4, 2, device=device)
            self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5), input)
            self._test_dropoutNd_channel_zero(nn.Dropout1d(p=0.5, inplace=True), input)

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout2d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        num_features = 1000
        input = torch.empty(num_features, b, w, h)
        self._test_dropout(nn.Dropout2d, device, input)
        self._test_dropout(
            nn.Dropout2d, device, input, memory_format=torch.channels_last
        )

        self._test_dropout_discontiguous(nn.Dropout2d, device)
        self._test_dropout_discontiguous(
            nn.Dropout2d, device, memory_format=torch.channels_last
        )

        with self.assertWarnsRegex(UserWarning, "Received a 5-D input to dropout2d"):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, 2, 2, device=device))

        with self.assertWarnsRegex(UserWarning, "Received a 2-D input to dropout2d"):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, device=device))

        # TODO: Uncomment these lines once no-batch-dim inputs are supported.
        # For now, the historical dropout1d behavior is performed for 3D inputs.
        # See https://github.com/pytorch/pytorch/issues/77081

        # input = torch.rand(50, 2, 2, device=device)
        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5), input)
        # self._test_dropoutNd_no_batch(nn.Dropout2d(p=0.5, inplace=True), input)

        with self.assertWarnsRegex(
            UserWarning, "assuming that channel-wise 1D dropout behavior is desired"
        ):
            nn.Dropout2d(p=0.5)(torch.rand(1, 2, 2, device=device))

        # check that complete channels are dropped
        input = torch.ones(10, 4, 2, 2, device=device)
        self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5), input)
        self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5, inplace=True), input)

    @expectedFailureXLA  # seems like freeze_rng_state is not honoured by XLA
    def test_Dropout3d(self, device):
        b = random.randint(1, 5)
        w = random.randint(1, 5)
        h = random.randint(1, 5)
        d = random.randint(1, 2)
        num_features = 1000
        input = torch.empty(num_features, b, d, w, h)
        self._test_dropout(nn.Dropout3d, device, input)

        self._test_dropout_discontiguous(nn.Dropout3d, device)
        self._test_dropout_discontiguous(
            nn.Dropout3d, device, memory_format=torch.channels_last
        )

        with self.assertWarnsRegex(UserWarning, "Received a 6-D input to dropout3d"):
            nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, 2, 2, 2, device=device))

        with self.assertWarnsRegex(UserWarning, "Received a 3-D input to dropout3d"):
            nn.Dropout3d(p=0.5)(torch.rand(1, 2, 2, device=device))

        # no batch dims
        input = torch.rand(50, 2, 2, 2, device=device)
        self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5), input)
        self._test_dropoutNd_no_batch(nn.Dropout3d(p=0.5, inplace=True), input)

        # check that complete channels are dropped
        input = torch.ones(10, 4, 2, 2, 2, device=device)
        self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5), input)
        self._test_dropoutNd_channel_zero(nn.Dropout3d(p=0.5, inplace=True), input)

    def test_empty_dropout(self, device):
        x = torch.tensor([]).to(device)
        out = torch.nn.functional.dropout(x)
        self.assertEqual(out.size(), x.size())


# Hand the device-generic test class (its tests take a ``device`` argument)
# and this module's globals to the device-type instantiation helper;
# presumably this creates the per-device test classes in this namespace.
instantiate_device_type_tests(TestDropoutNNDeviceType, globals())
# NOTE(review): no parametrized test is visible in TestDropoutNN in this file,
# so this call may be a no-op kept for consistency -- confirm before removing.
instantiate_parametrized_tests(TestDropoutNN)

# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
