# Owner(s): ["module: nn"]
import pickle
import unittest

import torch
import torch.nn as nn
from torch.nn import Buffer, Parameter
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import (
    run_tests,
    suppress_warnings,
    TEST_PRIVATEUSE1,
    TestCase,
)


class LazyModule(nn.modules.lazy.LazyModuleMixin, nn.Module):
    """Bare lazy module the tests use to hold uninitialized parameters and buffers."""


class TestLazyModules(TestCase):
    @suppress_warnings
    def test_lazy_module_parameter(self):
        """Exercise state_dict loading for a module that owns an uninitialized parameter."""
        lazy = LazyModule()
        lazy.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(lazy.has_uninitialized_params())
        lazy_state = lazy.state_dict()
        self.assertIsInstance(lazy_state["test_param"], UninitializedParameter)

        # Loading an uninitialized entry on top of an already materialized
        # parameter must be rejected.
        dense = LazyModule()
        dense.register_parameter("test_param", Parameter(torch.ones(5, 5)))
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            dense.load_state_dict(lazy_state)

        # An uninitialized parameter is overridden when the loaded state dict
        # carries a materialized value for it.
        dense = LazyModule()
        dense.register_parameter("test_param", Parameter(torch.ones(5, 5)))
        lazy.load_state_dict(dense.state_dict())
        self.assertEqual(lazy.test_param, torch.ones((5, 5)))

        # Loading uninitialized into uninitialized leaves the target lazy.
        lazy = LazyModule()
        lazy.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(lazy.has_uninitialized_params())

        other_lazy = LazyModule()
        other_lazy.register_parameter("test_param", UninitializedParameter())
        lazy.load_state_dict(other_lazy.state_dict())
        self.assertTrue(lazy.has_uninitialized_params())

    @suppress_warnings
    def test_lazy_module_buffer(self):
        """Exercise state_dict loading for a module that owns an uninitialized buffer."""
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())
        state_dict = module.state_dict()
        self.assertIsInstance(state_dict["test_buffer"], UninitializedBuffer)

        # An error is raised when there is an attempt to replace an existing
        # (materialized) buffer with an uninitialized one.
        new_module = LazyModule()
        new_module.test_buffer = Buffer(torch.ones(5, 5))
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            new_module.load_state_dict(state_dict)

        # Uninitialized buffers are overridden when the state dict to be
        # loaded contains a valid one.
        new_module = LazyModule()
        new_module.test_buffer = Buffer(torch.ones(5, 5))
        module.load_state_dict(new_module.state_dict())
        self.assertEqual(module.test_buffer, torch.ones((5, 5)))

        # Uninitialized buffers are left unchanged.
        module = LazyModule()
        module.test_buffer = UninitializedBuffer()
        self.assertTrue(module.has_uninitialized_params())

        new_module = LazyModule()
        new_module.test_buffer = UninitializedBuffer()
        # Loaded once, mirroring the parameter variant of this test; the
        # second identical call in the original was redundant.
        module.load_state_dict(new_module.state_dict())
        self.assertTrue(module.has_uninitialized_params())

    @suppress_warnings
    def test_lazy_module_jit_param(self):
        """Scripting a module that still holds an uninitialized parameter must raise."""
        lazy = LazyModule()
        lazy.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(lazy.has_uninitialized_params())
        expect_failure = self.assertRaisesRegex(RuntimeError, "run a forward pass")
        with expect_failure:
            torch.jit.script(lazy)

    @suppress_warnings
    def test_lazy_module_jit_buffer(self):
        """Scripting a module that still holds an uninitialized buffer must raise."""
        lazy = LazyModule()
        lazy.test_buffer = UninitializedBuffer()
        self.assertTrue(lazy.has_uninitialized_params())
        expect_failure = self.assertRaisesRegex(RuntimeError, "run a forward pass")
        with expect_failure:
            torch.jit.script(lazy)

    @suppress_warnings
    def test_lazy_share_memory_param(self):
        """share_memory() on a module with an uninitialized parameter must raise."""
        lazy = LazyModule()
        lazy.register_parameter("test_param", UninitializedParameter())
        self.assertTrue(lazy.has_uninitialized_params())
        expect_failure = self.assertRaisesRegex(
            RuntimeError, "share memory on an uninitialized"
        )
        with expect_failure:
            lazy.share_memory()

    @suppress_warnings
    def test_lazy_share_memory_buffer(self):
        """share_memory() on a module with an uninitialized buffer must raise."""
        lazy = LazyModule()
        lazy.test_buffer = UninitializedBuffer()
        self.assertTrue(lazy.has_uninitialized_params())
        expect_failure = self.assertRaisesRegex(
            RuntimeError, "share memory on an uninitialized"
        )
        with expect_failure:
            lazy.share_memory()

    @suppress_warnings
    def test_linear(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)
        self.assertIsInstance(module, nn.Linear)
        self.assertNotIsInstance(module, nn.LazyLinear)
        self.assertTrue(module.weight.shape == (10, 5))
        self.assertTrue(module.bias.shape == (10,))
        y = module(input)
        self.assertTrue(
            torch.equal(
                torch.nn.functional.linear(input, module.weight, module.bias), y
            )
        )

    @suppress_warnings
    def test_lazy_linear_pickle(self):
        module = nn.LazyLinear(10)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, nn.LazyLinear)
        self.assertIsInstance(module.weight, UninitializedParameter)
        self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(5, 5)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, nn.Linear)
        self.assertNotIsInstance(new_module, nn.LazyLinear)
        self.assertTrue(new_module.weight.shape == (10, 5))
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        self.assertTrue(new_module.bias.shape == (10,))
        self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    @suppress_warnings
    def test_linear_state(self):
        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Linear one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertTrue(lazy_module.weight.shape == (10, 5))
        self.assertTrue(lazy_module.bias.shape == (10,))

        module = nn.Linear(5, 10)
        lazy_module = nn.LazyLinear(10)
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_conv(
        self,
        cls,
        lazy_cls,
        func,
        init_args,
        input_shape,
        expected_weight_shape,
        expected_bias_shape,
        *forward_args,
        **forward_kwargs,
    ):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input, *forward_args, **forward_kwargs)
        self.assertIsInstance(module, cls)
        self.assertNotIsInstance(module, lazy_cls)
        self.assertEqual(module.weight.shape, expected_weight_shape)
        if module.bias is not None:
            self.assertEqual(module.bias.shape, expected_bias_shape)
        y = module(input)
        self.assertTrue(torch.equal(func(input, module.weight, module.bias), y))

    def _check_lazy_conv_pickle(
        self,
        cls,
        lazy_cls,
        init_args,
        input_shape,
        expected_weight_shape,
        expected_bias_shape,
    ):
        module = lazy_cls(*init_args)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(module, lazy_cls)
        self.assertIsInstance(module.weight, UninitializedParameter)
        if module.bias is not None:
            self.assertIsInstance(module.bias, UninitializedParameter)
        input = torch.ones(*input_shape)
        module(input)  # fully materialized
        new_module = pickle.loads(pickle.dumps(module))
        self.assertIsInstance(new_module, cls)
        self.assertNotIsInstance(new_module, lazy_cls)
        self.assertEqual(new_module.weight.shape, expected_weight_shape)
        self.assertNotIsInstance(new_module.weight, UninitializedParameter)
        if new_module.bias is not None:
            self.assertEqual(new_module.bias.shape, expected_bias_shape)
            self.assertNotIsInstance(new_module.bias, UninitializedParameter)

    def _check_lazy_conv_state(
        self, gen_module, gen_lazy_module, expected_weight_shape, expected_bias_shape
    ):
        module = gen_module()
        lazy_module = gen_lazy_module()
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Conv one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, expected_weight_shape)
        if lazy_module.bias is not None:
            self.assertEqual(lazy_module.bias.shape, expected_bias_shape)

        module = gen_module()
        lazy_module = gen_lazy_module()
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def test_lazy_pre_forward_hook(self):
        """
        This test is to test whether lazymodule can register other pre-forward hook
        functions successfully.
        """

        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input):
            return input[0] + 1

        module = TestModule()
        module.register_forward_pre_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    def test_lazy_forward_hook(self):
        """
        This test is to test whether lazymodule can register other forward hook
        functions successfully.
        """

        class TestModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module):
            def initialize_parameters(self, input):
                return None

            def forward(self, input):
                return input

        def hook_function(module, input, output):
            return input[0] + 1

        module = TestModule()
        module.register_forward_hook(hook_function)
        output = module(torch.zeros(2, 2))
        self.assertEqual(output, torch.ones(2, 2))

    @suppress_warnings
    def test_lazy_conv1d(self):
        self._check_lazy_conv(
            nn.Conv1d,
            nn.LazyConv1d,
            torch.nn.functional.conv1d,
            (32, 2),
            (192, 16, 50),
            (32, 16, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv1d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv1d, nn.LazyConv1d, (32, 2), (192, 16, 50), (32, 16, 2), (32,)
        )

    @suppress_warnings
    def test_lazy_conv1d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv1d(16, 32, 2),
            lambda: nn.LazyConv1d(32, 2),
            (32, 16, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv2d(self):
        self._check_lazy_conv(
            nn.Conv2d,
            nn.LazyConv2d,
            torch.nn.functional.conv2d,
            (32, 2),
            (192, 16, 8, 6),
            (32, 16, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv2d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv2d, nn.LazyConv2d, (32, 2), (192, 16, 8, 6), (32, 16, 2, 2), (32,)
        )

    @suppress_warnings
    def test_lazy_conv2d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv2d(16, 32, 2),
            lambda: nn.LazyConv2d(32, 2),
            (32, 16, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d(self):
        self._check_lazy_conv(
            nn.Conv3d,
            nn.LazyConv3d,
            torch.nn.functional.conv3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.Conv3d,
            nn.LazyConv3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv3d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.Conv3d(16, 32, 2),
            lambda: nn.LazyConv3d(32, 2),
            (32, 16, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transposed1d(self):
        self._check_lazy_conv(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            torch.nn.functional.conv_transpose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            torch.nn.functional.conv_transpose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
            output_size=(51,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose1d,
            nn.LazyConvTranspose1d,
            (32, 2),
            (192, 16, 50),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose1d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose1d(16, 32, 2),
            lambda: nn.LazyConvTranspose1d(32, 2),
            (16, 32, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d(self):
        self._check_lazy_conv(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            torch.nn.functional.conv_transpose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            torch.nn.functional.conv_transpose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
            output_size=(9, 7),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose2d,
            nn.LazyConvTranspose2d,
            (32, 2),
            (192, 16, 8, 6),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose2d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose2d(16, 32, 2),
            lambda: nn.LazyConvTranspose2d(32, 2),
            (16, 32, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d(self):
        self._check_lazy_conv(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            torch.nn.functional.conv_transpose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_kwargs(self):
        self._check_lazy_conv(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            torch.nn.functional.conv_transpose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
            output_size=(9, 8, 7),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_pickle(self):
        self._check_lazy_conv_pickle(
            nn.ConvTranspose3d,
            nn.LazyConvTranspose3d,
            (32, 2),
            (192, 16, 8, 7, 6),
            (16, 32, 2, 2, 2),
            (32,),
        )

    @suppress_warnings
    def test_lazy_conv_transpose3d_state(self):
        self._check_lazy_conv_state(
            lambda: nn.ConvTranspose3d(16, 32, 2),
            lambda: nn.LazyConvTranspose3d(32, 2),
            (16, 32, 2, 2, 2),
            (32,),
        )

    def _check_lazy_norm(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                lazy_module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )

                if affine:
                    self.assertIsInstance(lazy_module.weight, UninitializedParameter)
                    self.assertIsInstance(lazy_module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(lazy_module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(lazy_module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                lazy_output = lazy_module(input)
                self.assertIsInstance(lazy_module, cls)
                self.assertNotIsInstance(lazy_module, lazy_cls)

                num_features = input_shape[1]
                module = cls(
                    num_features, affine=affine, track_running_stats=track_running_stats
                )
                expected_output = module(input)

                self.assertEqual(lazy_output, expected_output)
                if module.weight is not None:
                    self.assertEqual(lazy_module.weight.shape, module.weight.shape)
                    self.assertEqual(lazy_module.weight, module.weight)
                if module.bias is not None:
                    self.assertEqual(lazy_module.bias.shape, module.bias.shape)
                    self.assertEqual(lazy_module.bias, module.bias)
                if module.running_mean is not None:
                    self.assertEqual(
                        lazy_module.running_mean.shape, module.running_mean.shape
                    )
                    self.assertEqual(lazy_module.running_mean, module.running_mean)
                if module.running_var is not None:
                    self.assertEqual(
                        lazy_module.running_var.shape, module.running_var.shape
                    )
                    self.assertEqual(lazy_module.running_var, module.running_var)
                if module.num_batches_tracked is not None:
                    self.assertEqual(
                        lazy_module.num_batches_tracked.shape,
                        module.num_batches_tracked.shape,
                    )
                    self.assertEqual(
                        lazy_module.num_batches_tracked, module.num_batches_tracked
                    )

    def _check_lazy_norm_pickle(self, cls, lazy_cls, input_shape):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )
                module = pickle.loads(pickle.dumps(module))

                self.assertIsInstance(module, lazy_cls)
                if affine:
                    self.assertIsInstance(module.weight, UninitializedParameter)
                    self.assertIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertIsInstance(module.running_var, UninitializedBuffer)

                input = torch.ones(*input_shape)
                module(input)  # fully materialized
                module = pickle.loads(pickle.dumps(module))

                self.assertNotIsInstance(module, lazy_cls)
                self.assertIsInstance(module, cls)
                if affine:
                    self.assertNotIsInstance(module.weight, UninitializedParameter)
                    self.assertNotIsInstance(module.bias, UninitializedParameter)
                if track_running_stats:
                    self.assertNotIsInstance(module.running_mean, UninitializedBuffer)
                    self.assertNotIsInstance(module.running_var, UninitializedBuffer)

    def _check_lazy_batchnorm_state(self, cls, lazy_cls):
        module = cls(10)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        lazy_module.load_state_dict(module.state_dict())
        # Parameters have been initialized but the module won't become a full
        # Conv one until the first iteration. This is due to
        # limitations on the state_dict loading logic
        self.assertFalse(lazy_module.has_uninitialized_params())
        self.assertEqual(lazy_module.weight.shape, (10,))
        self.assertEqual(lazy_module.bias.shape, (10,))
        self.assertEqual(lazy_module.running_mean.shape, (10,))
        self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10)
        lazy_module = lazy_cls()
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_instancenorm_state(self, cls, lazy_cls):
        for affine in [False, True]:
            for track_running_stats in [False, True]:
                module = cls(10, affine=affine, track_running_stats=track_running_stats)
                lazy_module = lazy_cls(
                    affine=affine, track_running_stats=track_running_stats
                )
                lazy_module.load_state_dict(module.state_dict())
                # Parameters have been initialized but the module won't become a full
                # InstanceNorm one until the first iteration. This is due to
                # limitations on the state_dict loading logic
                self.assertFalse(lazy_module.has_uninitialized_params())
                if affine:
                    self.assertEqual(lazy_module.weight.shape, (10,))
                    self.assertEqual(lazy_module.bias.shape, (10,))
                if track_running_stats:
                    self.assertEqual(lazy_module.running_mean.shape, (10,))
                    self.assertEqual(lazy_module.running_var.shape, (10,))

        module = cls(10, affine=True, track_running_stats=True)
        lazy_module = lazy_cls(affine=True, track_running_stats=True)
        with self.assertRaisesRegex(RuntimeError, "shape of an uninitialized"):
            module.load_state_dict(lazy_module.state_dict())

    def _check_lazy_norm_with_dict_input(self, cls, lazy_cls, input_shape):
        input = {"input": torch.ones(*input_shape)}

        lazy_module = lazy_cls()
        lazy_output = lazy_module(**input)

        num_features = input_shape[1]
        module = cls(num_features)
        expected_output = module(**input)

        self.assertEqual(lazy_output, expected_output)

    def test_lazy_batchnorm1d(self):
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6))
        self._check_lazy_norm_pickle(nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 6))

    def test_lazy_batchnorm1d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)
        self._check_lazy_batchnorm_state(nn.BatchNorm1d, nn.LazyBatchNorm1d)

    def test_lazy_batchnorm2d(self):
        self._check_lazy_norm(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_pickle(self):
        self._check_lazy_norm_pickle(nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7))

    def test_lazy_batchnorm2d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)
        self._check_lazy_batchnorm_state(nn.BatchNorm2d, nn.LazyBatchNorm2d)

    def test_lazy_batchnorm3d(self):
        self._check_lazy_norm(nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8))

    def test_lazy_batchnorm3d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_batchnorm3d_state(self):
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)
        self._check_lazy_batchnorm_state(nn.BatchNorm3d, nn.LazyBatchNorm3d)

    def test_lazy_instancenorm1d(self):
        self._check_lazy_norm(nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6))

    def test_lazy_instancenorm1d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm1d, nn.LazyInstanceNorm1d, (16, 3, 6)
        )

    def test_lazy_instancenorm1d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm1d, nn.LazyInstanceNorm1d)

    def test_lazy_instancenorm2d(self):
        self._check_lazy_norm(nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7))

    def test_lazy_instancenorm2d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm2d, nn.LazyInstanceNorm2d, (16, 3, 6, 7)
        )

    def test_lazy_instancenorm2d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm2d, nn.LazyInstanceNorm2d)

    def test_lazy_instancenorm3d(self):
        self._check_lazy_norm(
            nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_instancenorm3d_pickle(self):
        self._check_lazy_norm_pickle(
            nn.InstanceNorm3d, nn.LazyInstanceNorm3d, (16, 3, 6, 7, 8)
        )

    def test_lazy_instancenorm3d_state(self):
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)
        self._check_lazy_instancenorm_state(nn.InstanceNorm3d, nn.LazyInstanceNorm3d)

    def test_lazy_batchnorm_with_dict_input(self):
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm1d, nn.LazyBatchNorm1d, (16, 3, 6)
        )
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm2d, nn.LazyBatchNorm2d, (16, 3, 6, 7)
        )
        self._check_lazy_norm_with_dict_input(
            nn.BatchNorm3d, nn.LazyBatchNorm3d, (16, 3, 6, 7, 8)
        )

    @suppress_warnings
    def test_materialize_dtype(self):
        """materialize() uses the default dtype, or the one set through ``half()``."""
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.test_param.materialize(10)
        # assertEqual reports both dtypes on failure, unlike assertTrue(a == b).
        self.assertEqual(module.test_param.dtype, torch.get_default_dtype())

        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        # The dtype conversion is applied while the parameter is still uninitialized.
        module.half()
        module.test_param.materialize(10)
        self.assertEqual(module.test_param.dtype, torch.float16)

    @unittest.skipIf(
        not (TEST_CUDA or TEST_PRIVATEUSE1), "CUDA and PRIVATEUSE1 not available"
    )
    @suppress_warnings
    def test_materialize_device(self):
        """materialize() uses the CPU by default, or the device set through ``to()``."""
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        module.test_param.materialize(10)
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(module.test_param.device.type, "cpu")

        if TEST_CUDA:
            device = "cuda"
        else:
            # The skipIf guard above guarantees PRIVATEUSE1 is available when
            # CUDA is not, so ``device`` is bound on every path.
            device = torch._C._get_privateuse1_backend_name()
        module = LazyModule()
        module.register_parameter("test_param", UninitializedParameter())
        # The device move is applied while the parameter is still uninitialized.
        module.to(device)
        module.test_param.materialize(10)
        self.assertEqual(module.test_param.device.type, device)

    @suppress_warnings
    def test_chained_initialization(self):
        class MyNetwork(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear_1 = torch.nn.LazyLinear(15)
                self.linear_2 = torch.nn.LazyLinear(10)

            def forward(self, x):
                y = self.linear_1(x)
                return self.linear_2(y)

        net = MyNetwork()
        net(torch.ones(5, 10))
        self.assertTrue(net.linear_1.weight.shape == (15, 10))
        self.assertTrue(net.linear_1.bias.shape == (15,))
        self.assertTrue(net.linear_2.weight.shape == (10, 15))
        self.assertTrue(net.linear_2.bias.shape == (10,))

    @suppress_warnings
    def test_optimizer_pass(self):
        """Optimizers built on an uninitialized parameter can step once it is materialized.

        Adagrad is the exception: constructing it on an uninitialized
        parameter is expected to raise.
        """
        optimizers = [
            torch.optim.Adadelta,
            torch.optim.Adagrad,
            torch.optim.Adamax,
            torch.optim.Adam,
            torch.optim.AdamW,
            torch.optim.ASGD,
            torch.optim.SGD,
            torch.optim.Rprop,
            torch.optim.RMSprop,
            torch.optim.LBFGS,
            torch.optim.NAdam,
            torch.optim.RAdam,
        ]

        def run_step(module, optim):
            def tracked_param():
                # Re-read on each use: first parameter of the first group.
                return optim.param_groups[0]["params"][0]

            self.assertIsInstance(tracked_param(), UninitializedParameter)
            module.test_param.materialize(10)
            self.assertIsInstance(tracked_param(), Parameter)
            self.assertNotIsInstance(tracked_param(), UninitializedParameter)
            for p in module.parameters():
                p.grad = torch.rand_like(p)
            if isinstance(optim, torch.optim.LBFGS):
                # LBFGS.step requires a closure.
                optim.step(lambda: 1.0)
                return
            optim.step()

        for optim_cls in optimizers:
            module = LazyModule()
            module.register_parameter("test_param", UninitializedParameter())

            if optim_cls is torch.optim.Adagrad:
                with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
                    optim_cls(module.parameters())
                continue

            # SGD is the only optimizer built with an explicit learning rate.
            extra_kwargs = {"lr": 0.0} if optim_cls is torch.optim.SGD else {}
            run_step(module, optim_cls(module.parameters(), **extra_kwargs))

    @suppress_warnings
    def test_weight_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, "have uninitialized parameters."):
            m = torch.nn.utils.weight_norm(m)

    @suppress_warnings
    def test_spectral_norm(self):
        m = nn.LazyLinear(7)
        with self.assertRaisesRegex(ValueError, "have uninitialized parameters."):
            m = torch.nn.utils.spectral_norm(m)

    @suppress_warnings
    def test_invalid_functions(self):
        param = torch.nn.parameter.UninitializedParameter()
        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            torch.empty_like(param)

        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            torch.add(param, param)

        with self.assertRaisesRegex(ValueError, "uninitialized parameter"):
            param + param


# When executed as a script, hand control to the imported run_tests() helper.
if __name__ == "__main__":
    run_tests()
