# Owner(s): ["module: nn"]
import pickle
from copy import deepcopy
from itertools import product

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.parametrize as parametrize
from torch import Tensor
from torch.__future__ import get_swap_module_params_on_conversion
from torch.nn import Buffer, Parameter
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    gradcheck,
    instantiate_parametrized_tests,
    run_tests,
    set_default_dtype,
    skipIfNoLapack,
    skipIfTorchDynamo,
    swap,
    TemporaryFileName,
)
from torch.testing._internal.two_tensor import TwoTensor


class TestNNParametrization(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    # torch/nn/utils/parametrize
    @skipIfNoLapack
    @swap([True, False])
    def test_register_and_remove_parametrization(self):
        r"""Test that it is possible to add a few parametrizations
        on a parameter or a buffer, and that removing them restores the initial state.
        It also tests that backpropagating through them works as expected.
        """

        # Define a couple matrix parametrizations
        class Skew(nn.Module):
            def forward(self, X):
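                # Keep the strictly lower-triangular part; the result X - X.T is skew-symmetric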
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                # Cayley map
                # If X is skew-symmetric it returns an orthogonal matrix
                Id = torch.eye(X.size(0), device=X.device)
                # We call contiguous because solve returns a tensor with strides that are Fortran-contiguous
                # and autograd raises a performance warning.
                # This happens when we remove the parametrization with leave_parametrized=True,
                # which does a set_ with a non-contiguous tensor while the gradient is contiguous
                return torch.linalg.solve(Id + X, Id - X).contiguous()

        class Resize(nn.Module):
            def forward(self, X):
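                # Keep only the first row as a (1, n) matrix; this intentionally changes the shape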
                return X[[0]]

        class NoResize(nn.Module):
            def forward(self, X):
                return X

        # Define a couple vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)
        initial_weight_id = id(model.weight)
        initial_bias_id = id(model.bias)
        initial_model = deepcopy(model)

        # Test unsafe flag
        with self.assertRaisesRegex(
            ValueError,
            "Registering a parametrization may not change the shape of the tensor",
        ):
            parametrize.register_parametrization(
                model, "weight", Resize()
            )  # default unsafe = False
            model(torch.ones(8, 8))

        # One parametrization with unsafe=True
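        # unsafe=True skips the shape/dtype consistency checks, so shape-changing parametrizations are allowed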
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertTrue(model.weight.shape[0] == 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Two parametrizations with unsafe=True
        parametrize.register_parametrization(model, "weight", Resize(), unsafe=True)
        parametrize.register_parametrization(model, "weight", NoResize(), unsafe=False)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        self.assertTrue(model.weight.shape[0] == 1)
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test unsafe flag doesn't change expected behavior
        parametrize.register_parametrization(model, "weight", Skew(), unsafe=True)
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test one parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A
        # Remove and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertEqual(model.__class__, nn.Linear)

        # Test two parametrizations at the same time and removing them
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        # Result should be orthogonal
        X = model.weight
        Id = torch.eye(X.size(0), device=X.device)
        self.assertEqual(X.T @ X, Id)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del X
        # Structure tests
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertIn("weight", model.parametrizations)
        self.assertNotIn("weight", model._parameters)
        # Remove
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertEqual(model.weight, initial_model.weight)
        self.assertEqual(id(model.weight), initial_weight_id)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)

        # Add everything
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())

        # Basic tests
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertEqual(
            len(list(model.parameters())), 2
        )  # Nothing weird has happened
        # The training step below should not throw

        sgd = torch.optim.SGD(model.parameters(), lr=0.01)

        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
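        # A scalar loss involving both parametrized tensors, so backprop goes through both parametrizations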
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)

        # Remove first parametrization.
        # Check that the model is still parametrized and so is the second parameter
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertTrue(parametrize.is_parametrized(model))  # Still parametrized
        self.assertFalse(
            parametrize.is_parametrized(model, "weight")
        )  # Parametrization removed
        self.assertTrue(
            parametrize.is_parametrized(model, "bias")
        )  # Still parametrized
        self.assertEqual(model.bias[0].item(), 0.0)  # Still parametrized
        self.assertEqual(model.bias[-1].item(), 0.0)  # Still parametrized
        self.assertNotEqual(model.weight, initial_model.weight)  # Has been updated
        self.assertEqual(id(model.weight), initial_weight_id)  # Keeps the same id
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened
        # Should not throw
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)

        # Remove the second parametrization.
        # Check that the module is not parametrized
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=False)
        self.assertFalse(parametrize.is_parametrized(model))  # Not parametrized
        self.assertNotEqual(model.bias, initial_model.bias)  # Has been updated
        self.assertNotEqual(model.bias[0].item(), 0.0)  # Not parametrized
        self.assertNotEqual(model.bias[-1].item(), 0.0)  # Not parametrized
        self.assertEqual(id(model.bias), initial_bias_id)  # Keeps the same id
        self.assertFalse(
            hasattr(model, "parametrizations")
        )  # Not parametrized the module
        self.assertEqual(model.__class__, nn.Linear)  # Restores the previous class
        self.assertEqual(len(list(model.parameters())), 2)  # Nothing weird has happened

        # Should not throw. Things are updated
        weight_copy = model.weight.clone()
        bias_copy = model.bias.clone()
        sgd.zero_grad()
        (model.weight.T @ model.bias).sum().backward()
        sgd.step()
        self.assertNotEqual(model.weight, weight_copy)
        self.assertNotEqual(model.bias, bias_copy)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del weight_copy, bias_copy

        # Test leave_parametrized=True
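        # leave_parametrized=True replaces the parameter with the current output of the parametrization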
        for _ in range(2):
            parametrize.register_parametrization(model, "weight", Skew())
            parametrize.register_parametrization(model, "weight", Orthogonal())
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=True
            )
            # We didn't change the dtype, nor did we have multiple inputs, so the id should be the same
            self.assertEqual(id(model.weight), initial_weight_id)
            self.assertEqual(id(model.bias), initial_bias_id)

            # Should not throw. Things are updated
            weight_copy = model.weight.clone()
            bias_copy = model.bias.clone()
            sgd.zero_grad()
            (model.weight.T @ model.bias).sum().backward()
            sgd.step()
            self.assertNotEqual(model.weight, weight_copy)
            self.assertNotEqual(model.bias, bias_copy)
            if get_swap_module_params_on_conversion():
                # When using the swap_tensors path, this is needed so that the autograd
                # graph is not alive anymore.
                del weight_copy, bias_copy

    @swap([True, False])
    def test_register_and_remove_nested_parametrization(self):
        r"""Test that it is possible to nest the parametrizations
        meaning that the original param is parametrized again
        """

        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        model = nn.Linear(8, 8)
        # Add top level parametrization
        parametrize.register_parametrization(model, "weight", Skew())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be skew-symmetric
        A = model.weight
        self.assertEqual(A, -A.T)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del A

        # Add nested parametrization
        param_mod = model.parametrizations.weight
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertFalse(parametrize.is_parametrized(param_mod))
        self.assertFalse(parametrize.is_parametrized(param_mod, "original"))

        parametrize.register_parametrization(param_mod, "original", Skew())
        self.assertTrue(hasattr(param_mod, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(param_mod))
        self.assertTrue(parametrize.is_parametrized(param_mod, "original"))
        self.assertNotIn("original", param_mod._parameters)
        # Result should be skew-symmetric
        A = param_mod.original
        self.assertEqual(A, -A.T)

        # Remove nested param and check consistency
        parametrize.remove_parametrizations(
            param_mod, "original", leave_parametrized=False
        )
        self.assertFalse(hasattr(param_mod, "parametrizations"))
        self.assertEqual(param_mod.__class__, parametrize.ParametrizationList)

        # Remove top level and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=False)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)

    @swap([True, False])
    def test_register_and_remove_buffer_parametrization(self):
        r"""Test that it is possible to add and remove parametrizations on buffers"""

        # Define a couple vector parametrizations
        class FirstZero(nn.Module):
            def forward(self, x):
                return torch.cat([x.new_zeros(1), x[1:]])

        class LastZero(nn.Module):
            def forward(self, x):
                return torch.cat([x[:-1], x.new_zeros(1)])

        model = nn.Linear(8, 8)

        # Instantiate parametrizations on buffers. It should work as expected
        delattr(model, "bias")
        model.bias = Buffer(torch.ones(8))
        parametrize.register_parametrization(model, "bias", FirstZero())
        parametrize.register_parametrization(model, "bias", LastZero())
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        self.assertEqual(len(list(model.parameters())), 1)

        # Remove parametrizations on buffers. It should work as expected
        parametrize.remove_parametrizations(model, "bias", leave_parametrized=True)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertEqual(model.bias[0].item(), 0.0)
        self.assertEqual(model.bias[-1].item(), 0.0)
        self.assertTrue((model.bias[1:-1] == torch.ones(6)).all())
        self.assertEqual(len(list(model.parameters())), 1)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_serialization_parametrization(self):
        r"""Test that it is possible to serialize a parametrized model via state_dict"""

        # A stateful parametrization
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.id = Buffer(torch.eye(n))
                self.B = Buffer(torch.empty(n, n))
                init.orthogonal_(self.B)

            def forward(self, X):
                A = X.triu(1)
                A = A - A.T
                return self.B @ torch.linalg.solve(self.id + A, self.id - A)

        def get_model():
            model = torch.nn.Sequential(
                torch.nn.Linear(5, 5),
                torch.nn.ReLU(),
                torch.nn.Linear(5, 1),
            )

            parametrize.register_parametrization(model[0], "weight", Orthogonal(5))
            return model

        model = get_model()

        prev_weight = model[0].weight
        prev_B = model[0].parametrizations.weight[0].B

        new_model = get_model()
        with TemporaryFileName() as fname:
            torch.save(model.state_dict(), fname)
            new_model.load_state_dict(torch.load(fname))

        # Integrity tests
        self.assertTrue(parametrize.is_parametrized(new_model[0], "weight"))
        self.assertEqual(prev_weight, new_model[0].weight)
        self.assertEqual(prev_B, new_model[0].parametrizations.weight[0].B)

        # Trying to save the whole parametrized model raises
        with self.assertRaisesRegex(RuntimeError, "state_dict"):
            with TemporaryFileName() as fname:
                torch.save(model, fname)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_initialization_parametrization(self):
        r"""Test that it is possible to initialize a parametrization when it
        implements a `right_inverse` method
        """

        class Skew(nn.Module):
            def forward(self, X):
                A = X.triu(1)
                return A - A.T

            def is_skew(self, A):
                return torch.allclose(A, -A.T, atol=1e-6)

            def right_inverse(self, X):
                if not self.is_skew(X):
                    raise ValueError("The matrix is not skew-symmetric.")
                return X.triu(1)

        # Implements a Cayley map where right_inverse is not quite the inverse of forward
        class Orthogonal(nn.Module):
            def __init__(self, n):
                super().__init__()
                self.B = Buffer(torch.eye(n))

            def forward(self, X):
                Id = torch.eye(X.size(0))
                return self.B @ torch.linalg.solve(Id + X, Id - X)

            def is_orthogonal(self, X):
                Id = torch.eye(X.size(0))
                return torch.allclose(X.T @ X, Id, atol=1e-4)

            def right_inverse(self, X):
                if not self.is_orthogonal(X):
                    raise ValueError("The input is not orthogonal.")
                # cayley(0) == Id, so B @ cayley(0) == B
                self.B = X
                return torch.zeros_like(X)

        N = 5
        model = nn.Linear(N, N)
        # Register the skew-symmetric constraint. The resulting weight will be skew-symmetric
        skew = Skew()
        # Make the weight skew-symmetric before registering the parametrization
        with torch.no_grad():
            model.weight.set_(skew(model.weight))
        parametrize.register_parametrization(model, "weight", skew)
        X = torch.rand(N, N)
        # X is not skew-symmetric, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        # Make X skew-symmetric
        X = X - X.T
        model.weight = X
        self.assertEqual(model.parametrizations.weight.original, X.triu(1))
        self.assertEqual(model.weight, X)

        # Having several parametrizations registered should work in the same way
        # Now register the Cayley map on top. The result is now orthogonal
        parametrize.register_parametrization(model, "weight", Orthogonal(N))
        X = torch.rand(N, N)
        # X is not orthogonal, so it throws an error
        with self.assertRaises(ValueError):
            model.weight = X
        init.orthogonal_(X)
        model.weight = X
        self.assertEqual(model.weight, X)
        self.assertEqual(model.parametrizations.weight.original, torch.zeros_like(X))

    @swap([True, False])
    def test_errors_unparametrized_tensor_parametrization(self):
        # Test errors when registering a parametrization on an unparametrized tensor
        module = nn.Linear(3, 4)
        weight_init = module.weight.clone()

        class Identity(nn.Module):
            def forward(self, x):
                return x

        # Registering a parametrization on a non-existing parameter throws
        with self.assertRaisesRegex(ValueError, "does not have a parameter"):
            parametrize.register_parametrization(module, "foo", Identity())
        self.assertFalse(parametrize.is_parametrized(module))

        # Removing parametrizations from an unparametrized tensor throws
        with self.assertRaisesRegex(ValueError, "does not have a parametrization"):
            parametrize.remove_parametrizations(module, "bias")
        self.assertFalse(parametrize.is_parametrized(module))

        # A correct parametrization with several outputs
        class Sum(nn.Module):
            def forward(self, x, y):
                return x + y

            def right_inverse(self, z):
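                # Decompose z as (z, 0) so that forward(*right_inverse(z)) == z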
                return z, torch.zeros_like(z)

        parametrize.register_parametrization(module, "weight", Sum())
        # Cannot remove a parametrization with several outputs with `leave_parametrized=False`
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            parametrize.remove_parametrizations(
                module, "weight", leave_parametrized=False
            )
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

        # A parametrization with an incorrect number of outputs
        class WrongNumberParams(nn.Module):
            def forward(self, x, y, z):
                return x + y + z

            def right_inverse(self, w):
                return w, torch.zeros_like(w)

        # Makes param(*param.right_inverse(X)) fail
        with self.assertRaisesRegex(TypeError, "positional argument"):
            parametrize.register_parametrization(module, "weight", WrongNumberParams())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization with a right_inverse that does not return a Tensor or Sequence[Tensor]
        class WrongRightInverse(Identity):
            def right_inverse(self, z):
                return None

        # right_inverse should return a Tensor or a Sequence[Tensor]
        with self.assertRaisesRegex(ValueError, "Tensor or a Sequence of"):
            parametrize.register_parametrization(module, "weight", WrongRightInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # If it's a sequence, it must be a sequence of tensors
        class WrongRightInverseSequence(nn.Module):
            def forward(self, x, y):
                return x

            def right_inverse(self, z):
                return None, z

        with self.assertRaisesRegex(ValueError, "of the sequence with type"):
            parametrize.register_parametrization(
                module, "weight", WrongRightInverseSequence()
            )
        self.assertFalse(parametrize.is_parametrized(module))

        # A one-to-one parametrization whose right_inverse changes the dtype
        class ChangeDtypeInverse(nn.Module):
            def forward(self, x):
                return x.float()

            def right_inverse(self, w):
                return w.bool()

        # For parametrizations that return one tensor, right_inverse may not change the dtype
        with self.assertRaisesRegex(
            ValueError, "outputs one tensor, it may not change the dtype"
        ):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        self.assertFalse(parametrize.is_parametrized(module))

        # Doesn't return a tensor
        class NotTensor(nn.Module):
            def forward(self, x):
                return 2

        # Forward must return a tensor
        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", NotTensor())
        self.assertFalse(parametrize.is_parametrized(module))

        # A parametrization whose forward changes the dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        # forward should not change the initial dtype
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        self.assertFalse(parametrize.is_parametrized(module))

        # Change shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        # forward should not change the original shape
        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        self.assertFalse(parametrize.is_parametrized(module))

        # Many to one that changes dtype
        class ChangeDtypeMulti(nn.Module):
            def forward(self, x, y):
                return (x + y).bool()

            def right_inverse(self, w):
                return w, w + 1

        # forward should not change the original dtype even for parametrizations with many inputs
        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeMulti())
        self.assertFalse(parametrize.is_parametrized(module))

        # Returning a sequence of length one, although weird, is correct
        class SequenceLen1(nn.Module):
            def forward(self, x):
                return x

            def right_inverse(self, w):
                return (w,)

        parametrize.register_parametrization(module, "weight", SequenceLen1())
        self.assertTrue(hasattr(module.parametrizations.weight, "original0"))
        self.assertFalse(hasattr(module.parametrizations.weight, "original1"))
        _ = module.weight  # Does not throw
        self.assertTrue(parametrize.is_parametrized(module))
        parametrize.remove_parametrizations(module, "weight", leave_parametrized=True)

        # None of the operations above should have altered the weight
        self.assertFalse(parametrize.is_parametrized(module))
        self.assertEqual(module.weight, weight_init)

    @swap([True, False])
    def test_errors_parametrized_tensor_parametrization(self):
        # Test errors when registering a parametrization on a parametrized tensor

        class Identity(nn.Module):
            def forward(self, x):
                return x

        module = nn.Linear(3, 4)
        parametrize.register_parametrization(module, "weight", Identity())

        # Has to return a tensor
        class WrongReturn(nn.Module):
            def forward(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturn())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change dtype
        class ChangeDtype(nn.Module):
            def forward(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "may not change the dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtype())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change shape
        class ChangeShape(nn.Module):
            def forward(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "may not change the shape"):
            parametrize.register_parametrization(module, "weight", ChangeShape())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # The following checks mostly target bugs in the code of the parametrization itself

        # right_inverse has to return a tensor
        class WrongReturnInverse(Identity):
            def right_inverse(self, x):
                return x, x

        with self.assertRaisesRegex(ValueError, "right_inverse must return a tensor"):
            parametrize.register_parametrization(module, "weight", WrongReturnInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change dtype
        class ChangeDtypeInverse(Identity):
            def right_inverse(self, x):
                return x.bool()

        with self.assertRaisesRegex(ValueError, "must have the same dtype"):
            parametrize.register_parametrization(module, "weight", ChangeDtypeInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

        # Cannot change shape
        class ChangeShapeInverse(Identity):
            def right_inverse(self, x):
                return x[:-1]

        with self.assertRaisesRegex(ValueError, "must have the same shape"):
            parametrize.register_parametrization(module, "weight", ChangeShapeInverse())
        self.assertTrue(parametrize.is_parametrized(module))
        self.assertEqual(len(module.parametrizations.weight), 1)
        self.assertTrue(isinstance(module.parametrizations.weight[0], Identity))

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_multiple_inputs_parametrization(self):
        # A parametrization with several outputs
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # Project the given matrix onto the set of rank-1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is sorted in decreasing order.
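                # The best rank-1 approximation of Y is S[0] * outer(U[:, 0], Vh[0, :]);
                # split sqrt(S[0]) evenly between the two returned vectors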
                s0_sqrt = S[0].sqrt().unsqueeze(-1)
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        # A simple parametrization
        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, w):
                return 0.5 * w

        model = nn.Linear(3, 3)
        # Test one parametrization
        parametrize.register_parametrization(model, "weight", RankOne())
        self.assertTrue(hasattr(model, "parametrizations"))
        self.assertTrue(parametrize.is_parametrized(model))
        self.assertTrue(parametrize.is_parametrized(model, "weight"))
        self.assertTrue(hasattr(model.parametrizations.weight, "original0"))
        self.assertIn("original0", model.parametrizations.weight._parameters)
        self.assertTrue(hasattr(model.parametrizations.weight, "original1"))
        self.assertIn("original1", model.parametrizations.weight._parameters)
        self.assertFalse(parametrize.is_parametrized(model, "bias"))
        self.assertNotIn("weight", model._parameters)
        # Result should be rank 1
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)

        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            # Cannot remove a parametrization with multiple inputs and not leave it parametrized
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=False
            )
        # Remove parametrization and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
        self.assertIn("weight", model._parameters)

        # Registering parametrizations with one input on top of one with multiple inputs should work
        init_weight = model.weight.clone()
        parametrize.register_parametrization(model, "weight", RankOne())
        # Projecting a rank-1 matrix onto the set of rank-1 matrices does not change it
        self.assertEqual(init_weight, model.weight)
        parametrize.register_parametrization(model, "weight", Double())
        # The matrix is now twice the initial matrix
        self.assertEqual(2.0 * init_weight, model.weight)
        # Multiplying by a scalar does not change the rank
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)

        # The model now has three parameters
        self.assertEqual(len(list(model.parameters())), 3)

        sgd = torch.optim.SGD(model.parameters(), lr=0.1)

        # Test backward. Should not throw
        for _ in range(2):
            sgd.zero_grad()
            loss = (model.weight.T @ model.bias).sum()
            loss.backward()
            sgd.step()

        # Same drill as before, removing should work as expected
        with self.assertRaisesRegex(ValueError, "leave_parametrized=False"):
            # Cannot remove a parametrization with multiple inputs and not leave it parametrized
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=False
            )
        # Remove parametrization and check consistency
        parametrize.remove_parametrizations(model, "weight", leave_parametrized=True)
        self.assertFalse(hasattr(model, "parametrizations"))
        self.assertEqual(model.__class__, nn.Linear)
        self.assertFalse(parametrize.is_parametrized(model))
        self.assertEqual(torch.linalg.matrix_rank(model.weight).item(), 1)
        self.assertIn("weight", model._parameters)

        # The model now has two parameters
        self.assertEqual(len(list(model.parameters())), 2)

        # Test backward. Should not throw
        sgd = torch.optim.SGD(model.parameters(), lr=0.1)
        for _ in range(2):
            sgd.zero_grad()
            loss = (model.weight.T @ model.bias).sum()
            loss.backward()
            sgd.step()

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization(self):
        r"""Test the caching system of a parametrization"""

        # Define a couple matrix parametrizations
        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                Id = torch.eye(X.size(0), device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        # Test that the caching system works
        with parametrize.cached():
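            # Within parametrize.cached(), the parametrized value is computed once and reused,
            # so repeated accesses below return the very same tensor object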
            X = model.weight
            Y = model.weight
            self.assertEqual(id(X), id(Y))

    # FIXME: Rewrite this test using functions not depending on LAPACK
    #        and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_caching_parametrization_with_transfer_parametrizations_and_params(self):
        r"""Test that transferring parametrizations doesn't cause issues with caching"""

        class Skew(nn.Module):
            def forward(self, X):
                X = X.tril(-1)
                return X - X.T

        class Orthogonal(nn.Module):
            def forward(self, X):
                Id = torch.eye(X.size(0), device=X.device)
                return torch.linalg.solve(Id + X, Id - X)

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Skew())
        parametrize.register_parametrization(model, "weight", Orthogonal())

        to_model = nn.Linear(5, 5)
        parametrize.transfer_parametrizations_and_params(model, to_model)

        with parametrize.cached():
            X = model.weight
            Y = model.weight
            self.assertEqual(id(X), id(Y))

            A = to_model.weight
            B = to_model.weight
            self.assertEqual(id(A), id(B))

            # test that the results are distinct objects for each module
            self.assertNotEqual(id(A), id(X))

    @swap([True, False])
    def test_parametrization_same_training_mode(self):
        r"""Test training mode updated on parametrization registration"""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        module = nn.Linear(4, 4)
        module.eval()
        parametrize.register_parametrization(module, "weight", Identity())
        self.assertFalse(module.parametrizations.weight[0].training)
        module.train()
        parametrize.register_parametrization(module, "weight", Identity().eval())
        self.assertTrue(module.parametrizations.weight[0].training)
        self.assertTrue(module.parametrizations.weight[1].training)

    @swap([True, False])
    def test_type_before_parametrizations(self):
        r"""Test that type_before_parametrizations always retrieves original type"""

        class Identity(nn.Module):
            def forward(self, X):
                return X

        model = nn.Linear(5, 5)
        original_type = type(model)
        self.assertTrue(
            parametrize.type_before_parametrizations(model) == original_type
        )
        parametrize.register_parametrization(model, "weight", Identity())
        self.assertTrue(
            parametrize.type_before_parametrizations(model) == original_type
        )

    @swap([True, False])
    def test_deepcopy_after_parametrization(self):
        r"""Test that we are able to create a deepcopy of the module when it's parametrized."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class ModelWithoutDeepcopy(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(
                    torch.tensor([1.0, 1.0, 1.0, 1.0]), requires_grad=True
                )
                self.bias = nn.Parameter(
                    torch.tensor([0.0, 0.0, 0.0, 0.0]), requires_grad=True
                )
                self.attr = [1.0, 2.0, 3.0, 4.0]

        class ActualModel(ModelWithoutDeepcopy):
            # Emulate custom implementation of the deepcopying.
            def __deepcopy__(self, memo):
                result = self.__new__(self.__class__)
                memo[id(self)] = result
                result.__dict__ = deepcopy(self.__dict__, memo)
                return result

        def check_deepcopy(m1: nn.Module, m2: nn.Module):
            w1 = m1.parametrizations.weight.original
            w2 = m2.parametrizations.weight.original
            b1 = (
                m1.parametrizations.bias.original
                if parametrize.is_parametrized(m1, "bias")
                else m1.bias
            )
            b2 = (
                m2.parametrizations.bias.original
                if parametrize.is_parametrized(m2, "bias")
                else m2.bias
            )
            # Weights, biases and attributes should be equal but they must be different objects.
            self.assertEqual(m1.__dict__.keys(), m2.__dict__.keys())
            self.assertIsNot(m1, m2)
            self.assertEqual(w1, w2)
            self.assertIsNot(w1, w2)
            self.assertEqual(b1, b2)
            self.assertIsNot(b1, b2)
            self.assertEqual(m1.attr, m2.attr)
            self.assertIsNot(m1.attr, m2.attr)

        for model in (ModelWithoutDeepcopy(), ActualModel()):
            # General check that we are able to create deepcopy.
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models with several parametrized tensors.
            parametrize.register_parametrization(model, "bias", AddOne())
            check_deepcopy(model, deepcopy(model))
            # Check that this works on models where tensors have more than one parametrization.
            parametrize.register_parametrization(model, "weight", AddOne())
            check_deepcopy(model, deepcopy(model))

    @swap([True, False])
    def test_transfer_parametrizations_and_params(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", AddOne())
        parametrize.register_parametrization(model, "weight", Double())
        parametrize.register_parametrization(model, "weight", MinusOne())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
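        # to_model is of a different class (QAT Linear), so the transfer must work across module types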
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that the final and original values are correct and that to_model is parametrized
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)
        if get_swap_module_params_on_conversion():
            # When using the swap_tensors path, this is needed so that the autograd
            # graph is not alive anymore.
            del hold_weight

        # testing that changes to one set of parametrizations do not affect the other
        parametrize.remove_parametrizations(to_model, "weight")
        self.assertFalse(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(model, "weight"))

        # also test that parameters that don't exist in to_model get transferred
        model.test_param = Parameter(torch.randn(5, 5))

        self.assertTrue(not hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", Double())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        self.assertEqual(
            model.parametrizations.test_param.original,
            to_model.parametrizations.test_param.original,
        )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)

    @swap([True, False])
    def test_transfer_parametrizations_and_params_right_inverse(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

            def right_inverse(self, x):
                return 0.5 * x

        model = nn.Linear(5, 5)
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            5, 5, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that transfer occurs successfully
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )

        # check that transfer doesn't affect the from_model weight
        self.assertEqual(hold_weight, model.weight)

    @swap([True, False])
    def test_transfer_parametrizations_and_params_single_param(self):
        r"""Test that all parametrizations and their associated parameters are transferred."""

        class AddOne(nn.Module):
            def forward(self, x):
                return x + 1.0

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        class MinusOne(nn.Module):
            def forward(self, x):
                return x - 1.0

        model = nn.Linear(5, 5, bias=True)
        parametrize.register_parametrization(model, "weight", AddOne())
        parametrize.register_parametrization(model, "weight", Double())
        parametrize.register_parametrization(model, "weight", MinusOne())
        parametrize.register_parametrization(model, "bias", AddOne())
        parametrize.register_parametrization(model, "bias", Double())
        parametrize.register_parametrization(model, "bias", MinusOne())

        to_model = torch.ao.nn.qat.Linear(
            5, 5, bias=True, qconfig=torch.ao.quantization.get_default_qconfig()
        )
        parametrize.transfer_parametrizations_and_params(model, to_model, "weight")

        # check that weight and only weight was transferred
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original,
            to_model.parametrizations.weight.original,
        )
        self.assertTrue("bias" not in to_model.parametrizations)

    # FIXME: Rewrite this test using functions not depending on LAPACK
    # and remove the `@skipIfNoLapack` (see #70995)
    @skipIfNoLapack
    @swap([True, False])
    def test_transfer_parametrizations_and_params_many_to_one(self):
        # A parametrization with several outputs
        class RankOne(nn.Module):
            def forward(self, x, y):
                # Form a rank-1 matrix from a pair of vectors
                return x.unsqueeze(-1) @ y.unsqueeze(-2)

            def right_inverse(self, Y):
                # Project the given matrix onto the set of rank-1 matrices
                U, S, Vh = torch.linalg.svd(Y, full_matrices=False)
                # S is sorted in decreasing order.
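                # The best rank-1 approximation of Y is S[0] * outer(U[:, 0], Vh[0, :]);
                # split sqrt(S[0]) evenly between the two returned vectors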
                s0_sqrt = S[0].sqrt().unsqueeze(-1)
                return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt

        class Double(nn.Module):
            def forward(self, x):
                return 2.0 * x

        model = nn.Linear(3, 3)
        parametrize.register_parametrization(model, "weight", RankOne())
        parametrize.register_parametrization(model, "weight", Double())
        hold_weight = model.weight

        to_model = torch.ao.nn.qat.Linear(
            3, 3, qconfig=torch.ao.quantization.get_default_qconfig()
        )

        parametrize.transfer_parametrizations_and_params(model, to_model)

        # check that the final and original values are correct and that to_model is parametrized
        self.assertTrue(torch.nn.utils.parametrize.is_parametrized(to_model, "weight"))
        self.assertEqual(model.weight, to_model.weight)
        self.assertEqual(
            model.parametrizations.weight.original0,
            to_model.parametrizations.weight.original0,
        )
        self.assertEqual(
            model.parametrizations.weight.original1,
            to_model.parametrizations.weight.original1,
        )

        # check that the transfer didn't affect the original value
        self.assertEqual(hold_weight, model.weight)

        # also test that parameters that don't exist in to_model get transferred
        model.test_param = Parameter(torch.randn(3, 3))

        self.assertTrue(not hasattr(to_model, "test_param"))
        parametrize.register_parametrization(model, "test_param", RankOne())
        hold_test_param = model.test_param
        parametrize.transfer_parametrizations_and_params(model, to_model, "test_param")

        # also check that previously missing params got transferred correctly
        self.assertEqual(model.test_param, to_model.test_param)
        self.assertEqual(
            model.parametrizations.test_param.original0,
            to_model.parametrizations.test_param.original0,
        )
        self.assertEqual(
            model.parametrizations.test_param.original1,
            to_model.parametrizations.test_param.original1,
        )

        # check that the new transfer didn't change the value for the from_module
        self.assertEqual(hold_test_param, model.test_param)

    @swap([True, False])
    def test_new_spectral_norm(self):
        with set_default_dtype(torch.double):
            input = torch.randn(3, 5)
            m = nn.Linear(5, 7)
            m = torch.nn.utils.parametrizations.spectral_norm(m)
            spectral_norm_m = m.parametrizations.weight[0]

            self.assertEqual(spectral_norm_m._u.size(), torch.Size([m.weight.size(0)]))

            # .parametrizations.weight.original should be trainable
            self.assertTrue(hasattr(m.parametrizations.weight, "original"))
            self.assertTrue("original" in m.parametrizations.weight._parameters)

            # u should be just a reused buffer
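            # _u and _v approximate the leading left/right singular vectors and are updated by power iteration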
            self.assertTrue(hasattr(spectral_norm_m, "_u"))
            self.assertTrue("_u" in spectral_norm_m._buffers)
            self.assertTrue("_v" in spectral_norm_m._buffers)

            # weight should be a plain attribute, not counted as a buffer or a param
            self.assertIsNotNone(m.weight)
            self.assertFalse("weight" in m._buffers)
            self.assertFalse("weight" in m._parameters)

            # it should also be sharing storage with `parametrizations.weight.original`
            # self.assertEqual(m.parametrizations.weight.original.storage(), m.weight.storage())
            self.assertEqual(m.parametrizations.weight.original.size(), m.weight.size())
            self.assertEqual(
                m.parametrizations.weight.original.stride(), m.weight.stride()
            )

            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")

            # spectral_norm is the only parametrization
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertTrue("weight" in m._parameters)

            # We can register spectral_norm multiple times on the same parameter
            # and on multiple parameters in the same module
            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
            m = torch.nn.utils.parametrizations.spectral_norm(m, "weight")
            m = torch.nn.utils.parametrizations.spectral_norm(m, "bias")

            # If we remove the parametrization on bias, weight is still parametrized
            # Removing a parametrization runs forward in eval mode if leave_parametrized=True
            m = torch.nn.utils.parametrize.remove_parametrizations(m, "bias")
            self.assertTrue("bias" in m._parameters)
            self.assertTrue(hasattr(m, "parametrizations"))
            self.assertFalse("weight" in m._parameters)

            m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
            # Neither weight nor bias is parametrized
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertTrue("weight" in m._parameters)
            self.assertFalse(torch.nn.utils.parametrize.is_parametrized(m))

            # test correctness in training/eval modes and cpu/multi-gpu settings
            for apply_dp in (True, False):
                if apply_dp:
                    if not TEST_MULTIGPU:
                        continue
                    device = torch.device("cuda:0")

                    def maybe_wrap(m):
                        return torch.nn.DataParallel(m, [0, 1])

                else:
                    device = torch.device("cpu")

                    def maybe_wrap(m):
                        return m

                for requires_grad in (True, False):

                    def get_modules():
                        m = nn.Linear(3, 4).to(device)
                        m.weight.requires_grad_(requires_grad)
                        m = torch.nn.utils.parametrizations.spectral_norm(m)
                        wrapped_m = maybe_wrap(m)
                        spectral_norm_m = m.parametrizations.weight[0]
                        return m, wrapped_m, spectral_norm_m

                    input = torch.randn(2, 3, device=device)

                    m, wrapped_m, spectral_norm_m = get_modules()

                    self.assertTrue(hasattr(spectral_norm_m, "_u"))
                    u0 = spectral_norm_m._u.clone()
                    v0 = spectral_norm_m._v.clone()

                    # TEST TRAINING BEHAVIOR

                    # We perform GD first to modify the initial matrix
                    opt = torch.optim.SGD(wrapped_m.parameters(), lr=0.1)

                    opt.zero_grad()
                    wrapped_m(input).sum().backward()
                    opt.step()

                    out = wrapped_m(input)
                    if requires_grad:
                        # run forward again and assert that u and v are updated
                        self.assertNotEqual(u0, spectral_norm_m._u)
                        self.assertNotEqual(v0, spectral_norm_m._v)

                    # assert that backprop reaches original weight
                    # can't use gradcheck because the function changes as we
                    # forward through it in training mode (power iteration updates u and v)
                    if requires_grad:
                        torch.autograd.grad(
                            out.sum(), m.parametrizations.weight.original
                        )

                    # test that backward works with multiple forwards
                    # it uses training mode, so we need to reset the `u` and `v` vectors
                    # to the same values at the beginning for the finite difference test to pass
                    saved_u = spectral_norm_m._u.clone()
                    saved_v = spectral_norm_m._v.clone()

                    def fn(input):
                        spectral_norm_m._u.data.copy_(saved_u)
                        spectral_norm_m._v.data.copy_(saved_v)
                        out0 = wrapped_m(input)
                        out1 = wrapped_m(input)
                        return out0 + out1

                    # Make sure we can compute gradients with respect to all the
                    # parameters in the double-forward case
                    fn(input.clone().requires_grad_()).sum().backward()
                    gradcheck(
                        fn, (input.clone().requires_grad_(),), check_batched_grad=False
                    )

                    # test removing
                    # spectral norm module needs to be in eval mode if we'd like to
                    # avoid doing another power iteration
                    m, wrapped_m, _ = get_modules()
                    pre_remove_out = wrapped_m(input)
                    if get_swap_module_params_on_conversion():
                        # When using the swap_tensors path, this is needed so that the
                        # autograd graph is no longer kept alive.
                        pre_remove_out_ref = pre_remove_out.detach()
                        del pre_remove_out
                    else:
                        pre_remove_out_ref = pre_remove_out
                    m.eval()
                    m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
                    self.assertEqual(wrapped_m(input), pre_remove_out_ref)

                    torch.nn.utils.parametrizations.spectral_norm(m)
                    for _ in range(3):
                        pre_remove_out = wrapped_m(input)
                    if get_swap_module_params_on_conversion():
                        # When using the swap_tensors path, this is needed so that the
                        # autograd graph is no longer kept alive.
                        pre_remove_out_ref = pre_remove_out.detach()
                        del pre_remove_out
                    else:
                        pre_remove_out_ref = pre_remove_out
                    m.eval()
                    m = torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
                    self.assertEqual(wrapped_m(input), pre_remove_out_ref)

                    # TEST EVAL BEHAVIOR
                    m, wrapped_m, spectral_norm_m = get_modules()
                    wrapped_m(input)
                    last_train_out = wrapped_m(input)
                    last_train_u = spectral_norm_m._u.clone()
                    last_train_v = spectral_norm_m._v.clone()
                    wrapped_m.zero_grad()
                    wrapped_m.eval()

                    eval_out0 = wrapped_m(input)
                    # assert eval gives same result as last training iteration
                    self.assertEqual(eval_out0, last_train_out)
                    # assert that doing more iterations in eval does not change things
                    self.assertEqual(eval_out0, wrapped_m(input))
                    self.assertEqual(last_train_u, spectral_norm_m._u)
                    self.assertEqual(last_train_v, spectral_norm_m._v)

                    # FIXME: the code below is flaky when executed with DataParallel
                    # see https://github.com/pytorch/pytorch/issues/13818
                    if apply_dp:
                        continue

                    # test that backward works with multiple forwards in mixed training
                    # and eval modes
                    # it uses training mode, so we need to reset the `u` and `v` vectors
                    # to the same values at the beginning for the finite difference test to pass
                    saved_u = spectral_norm_m._u.clone()
                    saved_v = spectral_norm_m._v.clone()

                    def fn(input):
                        spectral_norm_m._u.data.copy_(saved_u)
                        spectral_norm_m._v.data.copy_(saved_v)
                        wrapped_m.train()
                        out0 = wrapped_m(input)
                        wrapped_m.eval()
                        out1 = wrapped_m(input)
                        wrapped_m.train()
                        out2 = wrapped_m(input)
                        wrapped_m.eval()
                        out3 = wrapped_m(input)
                        return out0 + out1 + out2 + out3

                    gradcheck(fn, (input.clone().requires_grad_(),))

                    # assert that backprop reaches the original weight in eval mode
                    if requires_grad:

                        def fn(weight):
                            return wrapped_m(input)

                        gradcheck(fn, (m.parametrizations.weight.original,))

    def test_register_parametrization_no_grad(self):
        r"""Test that it is possible to register a parametrization without gradient"""

        class SplitAndCat(nn.Module):
            def right_inverse(self, x):
                # split the tensor into two halves
                return torch.split(x, x.shape[1] // 2)

            def forward(self, x0, x1):
                return torch.cat([x0, x1])

        model = nn.Linear(8, 8)

        model.weight.requires_grad = False
        parametrize.register_parametrization(model, "weight", SplitAndCat())
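        # right_inverse returned two tensors, so they are stored as
        # `parametrizations.weight.original0` and `original1`, and forward()
        # receives them back as separate arguments.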
        # make sure that both the parametrized tensor and the decomposed tensors
        # have requires_grad == False
        self.assertFalse(model.weight.requires_grad)
        self.assertFalse(model.parametrizations.weight.original0.requires_grad)
        self.assertFalse(model.parametrizations.weight.original1.requires_grad)

    @swap([True, False])
    def test_new_spectral_norm_load_state_dict(self):
        for activate_times in (0, 3):
            inp = torch.randn(2, 3)
            m = nn.Linear(3, 5)
            snm = torch.nn.utils.parametrizations.spectral_norm(m)
            snm.train()

            for _ in range(activate_times):
                snm(inp)

            state_dict = deepcopy(snm.state_dict())
            self.assertEqual(
                {
                    "parametrizations.weight.original",
                    "bias",
                    "parametrizations.weight.0._v",
                    "parametrizations.weight.0._u",
                },
                set(state_dict.keys()),
            )
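            # The original weight lives under `parametrizations.weight.original`,
            # while the power-iteration buffers of the first (index 0) parametrization
            # are stored under `parametrizations.weight.0._u` / `._v`.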

            # test that non-strict loading works
            non_strict_state_dict = deepcopy(state_dict)
            non_strict_state_dict["nonsense"] = "nonsense"
            with self.assertRaisesRegex(
                RuntimeError, r'Unexpected key\(s\) in state_dict: "nonsense"'
            ):
                snm.load_state_dict(non_strict_state_dict, strict=True)
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["parametrizations.weight.original"]
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["parametrizations.weight.0._u"]
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["parametrizations.weight.0._v"]
            snm.load_state_dict(non_strict_state_dict, strict=False)
            non_strict_state_dict[
                "weight"
            ] = snm.weight.detach().clone()  # set W as a buffer
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict._metadata[
                "parametrizations.weight.0"
            ]  # remove metadata info
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["weight"]  # remove W buffer
            snm.load_state_dict(non_strict_state_dict, strict=False)
            del non_strict_state_dict["bias"]
            snm.load_state_dict(non_strict_state_dict, strict=False)

            # normal state_dict

            # test that re-wrapping does not matter
            m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
            snm = torch.nn.utils.parametrizations.spectral_norm(m)

            snm.load_state_dict(state_dict)
            with torch.no_grad():
                snm.eval()
                out0_eval = snm(inp)
                snm.train()
                out1_train = snm(inp)
                out2_train = snm(inp)
                snm.eval()
                out3_eval = snm(inp)

            # test that re-wrapping does not matter
            m = torch.nn.utils.parametrize.remove_parametrizations(snm, "weight")
            snm = torch.nn.utils.parametrizations.spectral_norm(m)

            # Test normal loading
            snm.load_state_dict(state_dict)
            with torch.no_grad():
                snm.eval()
                self.assertEqual(out0_eval, snm(inp))
                snm.train()
                self.assertEqual(out1_train, snm(inp))
                self.assertEqual(out2_train, snm(inp))
                snm.eval()
                self.assertEqual(out3_eval, snm(inp))

    @swap([True, False])
    def test_new_spectral_norm_dim(self):
        inp = torch.randn(2, 3, 10, 12)
        m = nn.ConvTranspose2d(3, 4, (5, 6))
        m = torch.nn.utils.parametrizations.spectral_norm(m)
        snm = m.parametrizations.weight[0]
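        # For ConvTranspose{N}d modules, spectral_norm defaults to dim=1 (the
        # output-channel dimension), so `_u` should have the size of weight dim 1
        # rather than dim 0.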
        # this should not run into incompatible shapes
        x = m(inp)
        # check that `_u` has the size of the chosen dimension of the weight
        self.assertEqual(
            snm._u.shape, m.parametrizations.weight.original[0, :, 0, 0].shape
        )

    @swap([True, False])
    def test_new_spectral_norm_forward(self):
        input = torch.randn(3, 5)
        m = nn.Linear(5, 7)
        m = torch.nn.utils.parametrizations.spectral_norm(m)
        snm = m.parametrizations.weight[0]
        # naive forward
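        # Reference computation: one power-iteration step followed by rescaling,
        #   u = normalize(W_mat @ v),  v = normalize(W_mat.T @ u),
        #   sigma = u . (W_mat @ v),   weight_sn = weight / sigma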
        _weight = m.parametrizations.weight.original
        _bias, _v = m.bias, snm._v
        _weight_mat = _weight.view(_weight.size(0), -1)
        _u = torch.mv(_weight_mat, _v)
        _u = F.normalize(_u, dim=0, eps=1e-12)
        _v = torch.mv(_weight_mat.t(), _u)
        _v = F.normalize(_v, dim=0, eps=1e-12)
        _weight.data /= torch.dot(_u, torch.matmul(_weight_mat, _v))
        out_hat = torch.nn.functional.linear(input, _weight, _bias)
        expect_out = m(input)
        self.assertEqual(expect_out, out_hat)

    @swap([True, False])
    @skipIfTorchDynamo("Test does not work with TorchDynamo")
    def test_new_spectral_norm_value(self):
        # Test that the spectral norm (i.e., the top singular value)
        # is properly computed, using the example of a simple diagonal matrix.
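        # For W = diag(d), the singular values are |d_i|, so spectral_norm should
        # rescale the weight to diag(d / max_i |d_i|).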
        for dtype in (torch.float, torch.cfloat):
            m = nn.Linear(2, 2, dtype=dtype)
            with torch.no_grad():
                # set weight to be diagonal
                x = torch.diagonal(m.weight)
                m.weight = nn.Parameter(torch.diag(x))
                torch.nn.utils.parametrizations.spectral_norm(m)
                # the weight should be rescaled by its spectral norm
                # (i.e., the largest diagonal element in absolute value)
                expected = torch.diag(x / x.abs().max())
                self.assertEqual(m.weight.data, expected)

    @skipIfNoLapack
    @swap([True, False])
    def test_orthogonal_parametrization(self):
        # Orthogonal implements 6 algorithms (3 parametrizations times 2 options for use_trivialization)

        def assert_is_orthogonal(X):
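            # A matrix with orthonormal columns satisfies X^H X = I; when the
            # matrix is wide we transpose first so that the smaller Gram matrix
            # is checked.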
            n, k = X.size(-2), X.size(-1)
            if n < k:
                X = X.mT
                n, k = k, n
            Id = torch.eye(k, dtype=X.dtype, device=X.device).expand(
                *(X.size()[:-2]), k, k
            )
            eps = 10 * n * torch.finfo(X.dtype).eps
            torch.testing.assert_close(X.mH @ X, Id, atol=eps, rtol=0.0)

        def assert_weight_allclose_Q(weight, W):
            # Test that weight is equal to the Q part of the QR decomposition of W
            # (or of its transpose if the matrix is wide)
            wide_matrix = W.size(-2) < W.size(-1)
            if wide_matrix:
                W = W.mT
            Q, R = torch.linalg.qr(W)
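            # The QR decomposition is only unique up to the signs (phases) of the
            # columns of Q; scaling each column by the sign of the matching
            # diagonal entry of R below removes that ambiguity.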
            Q *= R.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
            if wide_matrix:
                Q = Q.mT
            torch.testing.assert_close(Q, weight, atol=1e-5, rtol=0.0)

        for shape, dtype, use_linear in product(
            ((4, 4), (5, 3), (3, 5)),  # square / tall / wide
            (torch.float32, torch.complex64),
            (True, False),
        ):
            # Conv2d does not support complex yet, so the Conv2d cases are skipped
            if not use_linear:
                continue

            if use_linear:
                input = torch.randn(3, shape[0], dtype=dtype)
            else:
                input = torch.randn(2, 2, shape[0] + 2, shape[1] + 1, dtype=dtype)

            for parametrization, use_trivialization in product(
                ("matrix_exp", "cayley", "householder"), (False, True)
            ):
                # right_inverse for Cayley and matrix_exp not implemented for use_trivialization=False
                # See Note [right_inverse expm cayley]
                can_initialize = use_trivialization or parametrization == "householder"

                # We generate them every time to always start with fresh weights
                if use_linear:
                    m = nn.Linear(*shape, dtype=dtype)
                else:
                    m = nn.Conv2d(2, 3, shape, dtype=dtype)

                # We do not support householder for complex inputs
                # See Note [Householder complex]

                # When using the swap_tensors path, this is needed so that the
                # autograd graph is no longer kept alive.
                if get_swap_module_params_on_conversion():
                    w_init = m.weight.clone().detach()
                else:
                    w_init = m.weight.clone()
                if parametrization == "householder" and m.weight.is_complex():
                    msg = "householder parametrization does not support complex tensors"
                    with self.assertRaisesRegex(ValueError, msg):
                        torch.nn.utils.parametrizations.orthogonal(
                            m,
                            "weight",
                            parametrization,
                            use_trivialization=use_trivialization,
                        )
                    continue

                wide_matrix = w_init.size(-2) < w_init.size(-1)
                torch.nn.utils.parametrizations.orthogonal(
                    m, "weight", parametrization, use_trivialization=use_trivialization
                )
                # Forwards works as expected
                self.assertEqual(w_init.shape, m.weight.shape)
                assert_is_orthogonal(m.weight)
                if can_initialize:
                    assert_weight_allclose_Q(m.weight, w_init)

                # Initializing with a given orthogonal matrix works
                X = torch.randn_like(m.weight)
                if wide_matrix:
                    X = X.mT
                w_new = torch.linalg.qr(X).Q
                if wide_matrix:
                    w_new = w_new.mT
                if can_initialize:
                    m.weight = w_new
                    torch.testing.assert_close(w_new, m.weight, atol=1e-5, rtol=0.0)
                else:
                    msg = (
                        "assign to the matrix exponential or the Cayley parametrization"
                    )
                    with self.assertRaisesRegex(NotImplementedError, msg):
                        m.weight = w_new

                # Initializing with a non-orthogonal matrix makes m.weight the Q part of the given matrix
                w_new = torch.randn_like(m.weight)
                if can_initialize:
                    m.weight = w_new
                    assert_weight_allclose_Q(m.weight, w_new)
                else:
                    msg = (
                        "assign to the matrix exponential or the Cayley parametrization"
                    )
                    with self.assertRaisesRegex(NotImplementedError, msg):
                        m.weight = w_new

                opt = torch.optim.SGD(m.parameters(), lr=0.1)
                for _ in range(2):
                    opt.zero_grad()
                    m(input).norm().backward()
                    grad = m.parametrizations.weight.original.grad
                    self.assertIsNotNone(grad)
                    # We do not update the strictly upper triangular part if the matrix is tall,
                    # nor the strictly lower triangular part if it is wide
                    if grad.size(-2) >= grad.size(-1):
                        zeros_grad = grad.triu(1)
                    else:
                        zeros_grad = grad.tril(-1)
                    self.assertEqual(zeros_grad, torch.zeros_like(zeros_grad))
                    # The gradient on the diagonal can only be imaginary because a skew-Hermitian
                    # matrix has a purely imaginary diagonal
                    diag_grad = grad.diagonal(dim1=-2, dim2=-1)
                    if grad.is_complex():
                        diag_grad = diag_grad.real
                    self.assertEqual(diag_grad, torch.zeros_like(diag_grad))
                    opt.step()
                    assert_is_orthogonal(m.weight)

    @skipIfNoLapack
    @swap([True, False])
    def test_orthogonal_errors(self):
        m = nn.Linear(3, 4)
        with self.assertRaisesRegex(ValueError, "has to be one of"):
            torch.nn.utils.parametrizations.orthogonal(m, "weight", "foo")

        with self.assertRaisesRegex(ValueError, "Expected a matrix"):
            torch.nn.utils.parametrizations.orthogonal(m, "bias")

        torch.nn.utils.parametrizations.orthogonal(m, "weight")
        with self.assertRaisesRegex(ValueError, "matrices of shape"):
            m.weight = torch.randn(5, 5)
        torch.nn.utils.parametrize.remove_parametrizations(m, "weight")

    @swap([True, False])
    def test_weight_norm_state_dict_compat(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.weight_norm(m)
        old_dict = m.state_dict()
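        # The old-style utils.weight_norm state_dict stores `weight_g` / `weight_v`;
        # loading it into the parametrized module below is expected to be remapped
        # onto `parametrizations.weight.original0` / `original1`.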

        m2 = nn.Linear(4, 5)
        m2 = torch.nn.utils.parametrizations.weight_norm(m2)
        m2.load_state_dict(old_dict)

        input = torch.randn(3, 4)
        self.assertEqual(m(input), m2(input))

    @swap([True, False])
    def test_weight_norm_pickle(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.parametrizations.weight_norm(m)
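        # Parametrized modules use a dynamically generated class (e.g. a
        # ParametrizedLinear subclass), so pickling the module directly is
        # expected to fail with an error pointing the user at state_dict.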
        with self.assertRaisesRegex(RuntimeError, "state_dict"):
            pickle.dumps(m)

    @swap([True, False])
    def test_weight_norm_deepcopy(self):
        m = nn.Linear(4, 5)
        m = torch.nn.utils.parametrizations.weight_norm(m)
        m2 = deepcopy(m)
        input = torch.randn(3, 4)
        self.assertEqual(m(input), m2(input))

    @swap([True])
    def test_wrapper_subclass_parametrization(self):
        class Subclassify(nn.Module):
            def forward(self, X):
                return TwoTensor(X, X)

        class UnSubclassify(nn.Module):
            def forward(self, X):
                return X.a

        class IdentityWithRightInverse(nn.Module):
            def forward(self, X):
                return X

            def right_inverse(self, X):
                return TwoTensor(X, X)

        def _check_parametrization(
            parametrization,
            type_before_registration,
            type_after_registration,
            leave_parametrized=False,
            type_after_right_inverse=None,
        ):
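            # Register `parametrization` on both a parameter (`weight`) and a buffer
            # (`buf`), then check that the tensor types observed before registration,
            # after registration, and after removal match the expected classes.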
            model = nn.Linear(2, 2)
            buf = torch.randn(2, 2)
            model.buf = torch.nn.Buffer(buf)
            if (
                type_before_registration == TwoTensor
                and type_after_registration == Tensor
            ):
                model._apply(lambda t: TwoTensor(t, t))
            initial_weight = model.weight.clone().detach()
            initial_weight_id = id(model.weight)
            initial_buf = model.buf.clone().detach()
            initial_buf_id = id(model.buf)
            type_original_weight = (
                type_before_registration
                if type_after_right_inverse is None
                else type_after_right_inverse
            )
            type_original_buf = (
                Tensor if type_original_weight is nn.Parameter else type_original_weight
            )
            type_after_removal_buf = (
                type_after_registration if leave_parametrized else type_original_buf
            )
            if leave_parametrized:
                if type_after_registration is Tensor:
                    type_after_removal_weight = nn.Parameter
                else:
                    type_after_removal_weight = type_after_registration
            else:
                type_after_removal_weight = type_original_weight

            parametrize.register_parametrization(model, "weight", parametrization())
            parametrize.register_parametrization(model, "buf", parametrization())
            self.assertTrue(hasattr(model, "parametrizations"))
            self.assertTrue(parametrize.is_parametrized(model))
            self.assertFalse(parametrize.is_parametrized(model, "bias"))
            # checks for weight
            self.assertTrue(parametrize.is_parametrized(model, "weight"))
            self.assertTrue(
                isinstance(model.parametrizations.weight.original, nn.Parameter)
            )
            self.assertTrue(
                type(model.parametrizations.weight.original) is type_original_weight
            )
            self.assertNotIn("weight", model._parameters)
            self.assertTrue(type(model.weight) is type_after_registration)
            # checks for buf
            self.assertTrue(parametrize.is_parametrized(model, "buf"))
            self.assertFalse(
                isinstance(model.parametrizations.buf.original, nn.Parameter)
            )
            self.assertTrue(
                type(model.parametrizations.buf.original) is type_original_buf
            )
            self.assertTrue(type(model.buf) is type_after_registration)
            parametrize.remove_parametrizations(
                model, "weight", leave_parametrized=leave_parametrized
            )
            parametrize.remove_parametrizations(
                model, "buf", leave_parametrized=leave_parametrized
            )
            self.assertFalse(hasattr(model, "parametrizations"))
            self.assertEqual(model.__class__, nn.Linear)
            # checks for weight
            self.assertTrue(type(model.weight) is type_after_removal_weight)
            self.assertTrue(isinstance(model.weight, nn.Parameter))
            self.assertEqual(id(model.weight), initial_weight_id)
            # checks for buf
            self.assertTrue(type(model.buf) is type_after_removal_buf)
            self.assertFalse(isinstance(model.buf, nn.Parameter))
            self.assertEqual(id(model.buf), initial_buf_id)
            if not leave_parametrized and type_after_right_inverse is None:
                self.assertEqual(model.weight, initial_weight)
                self.assertEqual(model.buf, initial_buf)

        _check_parametrization(Subclassify, nn.Parameter, TwoTensor)
        _check_parametrization(UnSubclassify, TwoTensor, Tensor)
        _check_parametrization(
            IdentityWithRightInverse,
            nn.Parameter,
            TwoTensor,
            type_after_right_inverse=TwoTensor,
        )
        _check_parametrization(
            Subclassify, nn.Parameter, TwoTensor, leave_parametrized=True
        )
        _check_parametrization(
            UnSubclassify, TwoTensor, Tensor, leave_parametrized=True
        )
        _check_parametrization(
            IdentityWithRightInverse,
            nn.Parameter,
            TwoTensor,
            leave_parametrized=True,
            type_after_right_inverse=TwoTensor,
        )


class TestNNParametrizationDevice(NNTestCase):
    @swap([True, False])
    def test_weight_norm_parametrization(self, device):
        for dtype in [torch.float, torch.bfloat16]:
            input = torch.randn(3, 4, dtype=dtype, device=device)
            m = nn.Linear(4, 5, dtype=dtype, device=device)
            expected_output = m(input)

            # add weight normalization
            m = torch.nn.utils.parametrizations.weight_norm(m)
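            # weight_norm decomposes w = g * v / ||v||: `original0` holds the
            # per-slice magnitude g (shape (5, 1) for the default dim=0) and
            # `original1` holds the direction v with the same shape as the weight.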
            self.assertEqual(
                m.parametrizations.weight.original1.size(), m.weight.size()
            )
            self.assertEqual(m.parametrizations.weight.original0.size(), (5, 1))
            self.assertEqual(m(input), expected_output)

            # remove weight norm
            torch.nn.utils.parametrize.remove_parametrizations(m, "weight")
            self.assertFalse(hasattr(m, "parametrizations"))
            self.assertEqual(m(input), expected_output)

            # test with dim=1
            m = torch.nn.utils.parametrizations.weight_norm(m, dim=1)
            self.assertEqual(
                m.parametrizations.weight.original1.size(), m.weight.size()
            )
            self.assertEqual(m.parametrizations.weight.original0.size(), (1, 4))
            self.assertEqual(m(input), expected_output)

            # test with dim=None
            m = nn.Linear(4, 5, dtype=dtype, device=device)
            expected_output = m(input)
            m = torch.nn.utils.parametrizations.weight_norm(m, dim=None)
            self.assertEqual(m(input), expected_output)


only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestNNParametrizationDevice, globals(), only_for=only_for)
instantiate_parametrized_tests(TestNNParametrization)

if __name__ == "__main__":
    run_tests()
