# Owner(s): ["module: nn"]
import re
import unittest
from copy import deepcopy
from itertools import product

import torch
import torch.nn as nn
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfCrossRef,
    skipIfTorchDynamo,
    swap,
    TEST_NUMPY,
    TestCase,
)
from torch.utils._pytree import tree_map


if TEST_NUMPY:
    import numpy as np


class TestLoadStateDict(NNTestCase):
    _do_cuda_memory_leak_check = True
    _do_cuda_non_default_stream = True

    @unittest.skipIf(not TEST_NUMPY, "numpy not found")
    @swap([True, False])
    def test_load_state_dict_invalid(self):
        m = torch.nn.Linear(2, 2, bias=False)

        state_dict = {"weight": np.random.randn(2, 2)}
        with self.assertRaisesRegex(
            RuntimeError,
            "expected torch.Tensor or Tensor-like object from checkpoint but received",
        ):
            m.load_state_dict(state_dict)

        state_dict = {"weight": ((1.0, 1.0), (2.0, 2.0))}
        with self.assertRaisesRegex(
            RuntimeError,
            "expected torch.Tensor or Tensor-like object from checkpoint but received",
        ):
            m.load_state_dict(state_dict)

    @swap([True, False])
    def test_load_state_dict_type(self):
        m = nn.Module()

        with self.assertRaisesRegex(
            TypeError, "Expected state_dict to be dict-like, got"
        ):
            m.load_state_dict("")
        with self.assertRaisesRegex(
            TypeError, "Expected state_dict to be dict-like, got"
        ):
            m.load_state_dict(2)

    @swap([True, False])
    @skipIfTorchDynamo("dynamo installs weakrefs on some params")
    def test_load_state_dict(self):
        l = nn.Linear(5, 5)
        block = nn.Module()
        block.conv1 = nn.Conv2d(3, 3, 3, bias=True)
        block.conv2 = nn.Conv2d(3, 3, 3, bias=False)
        net = nn.Module()
        net.linear1 = l
        net.linear2 = l
        net.bn = nn.BatchNorm2d(2)
        net.block = block
        net.add_module("empty", None)
        conv1_bias_dtype = block.conv1.bias.dtype

        state_dict = net.state_dict()
        state_dict.update(
            {
                "linear1.weight": torch.ones(5, 5),
                "block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
                "bn.running_mean": torch.randn(2),
            }
        )
        # Also test if a DDP state_dict can be loaded from a local model.
        ddp_state_dict = net.state_dict()
        ddp_state_dict.update(
            {
                "module.linear1.weight": torch.ones(5, 5),
                "module.block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
                "module.bn.running_mean": torch.randn(2),
            }
        )
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(
            ddp_state_dict, "module."
        )
        for sd in [state_dict, ddp_state_dict]:
            incompatible_keys = net.load_state_dict(sd)
            self.assertEqual(len(incompatible_keys.missing_keys), 0)
            self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
            self.assertNotIn("Incompatible", str(incompatible_keys))
            self.assertEqual(net.linear1.weight, sd["linear1.weight"])
            self.assertEqual(net.block.conv1.bias, sd["block.conv1.bias"])
            self.assertEqual(net.bn.running_mean, sd["bn.running_mean"])

        state_dict = net.state_dict()
        state_dict.update({"extra": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("extra", incompatible_keys.unexpected_keys)
        self.assertIn("Incompatible", str(incompatible_keys))

        state_dict = net.state_dict()
        state_dict.update({"extra.param": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 0)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("extra.param", incompatible_keys.unexpected_keys)

        state_dict = net.state_dict()
        del state_dict["linear1.weight"]
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 0)
        self.assertIn("linear1.weight", incompatible_keys.missing_keys)
        state_dict.update({"extra.param": torch.ones(5)})
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        incompatible_keys = net.load_state_dict(state_dict, strict=False)
        self.assertEqual(len(incompatible_keys.missing_keys), 1)
        self.assertEqual(len(incompatible_keys.unexpected_keys), 1)
        self.assertIn("linear1.weight", incompatible_keys.missing_keys)
        self.assertIn("extra.param", incompatible_keys.unexpected_keys)

        state_dict = net.state_dict()
        state_dict.update({"bn.running_mean": torch.rand(14, 4)})  # wrong size
        self.assertRaises(RuntimeError, lambda: net.load_state_dict(state_dict))
        self.assertRaises(
            RuntimeError, lambda: net.load_state_dict(state_dict, strict=False)
        )

        state_dict = net.state_dict()
        old_state_dict = deepcopy(state_dict)
        state_dict = {
            "linear1.weight": torch.ones(5, 5),
            "block.conv1.bias": torch.arange(1, 4, dtype=conv1_bias_dtype),
            "bn.running_mean": torch.randn(2),
            "nonexistent_key": torch.rand(3),
        }
        net.load_state_dict(state_dict, strict=False)
        self.assertEqual(net.linear1.weight, state_dict["linear1.weight"])
        self.assertEqual(net.block.conv1.bias, state_dict["block.conv1.bias"])
        self.assertEqual(net.bn.running_mean, state_dict["bn.running_mean"])
        new_state_dict = net.state_dict()
        del old_state_dict["linear1.weight"]
        del old_state_dict["block.conv1.bias"]
        del old_state_dict["bn.running_mean"]
        for (
            k,
            v,
        ) in old_state_dict.items():
            self.assertTrue(v.equal(new_state_dict[k]))

    @swap([True, False])
    def test_load_state_dict_BC(self):
        # BatchNormNd
        # Added num_batches_tracked buffer at version 2. For state dict with
        # earlier versions or no versions, it should provide default value of 0.
        bn = nn.BatchNorm2d(3)
        state_dict = bn.state_dict()
        del state_dict["num_batches_tracked"]
        state_dict._metadata[""]["version"] = 1  # version 1
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)
        del state_dict._metadata[""]["version"]  # no version
        bn.load_state_dict(state_dict)
        self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
        self.assertEqual(bn.num_batches_tracked.item(), 0)

    @swap([True, False])
    def test_load_state_dict_child(self):
        base_module = nn.Linear(1, 1)
        model = base_module
        for _ in range(3):
            model = nn.Sequential(*[deepcopy(model) for _ in range(10)])

        def hook_fn(
            module,
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        ):
            module_state_dict = module.state_dict()
            self.assertEqual(len(module_state_dict.keys()), len(state_dict.keys()))

        model[0][0].register_load_state_dict_pre_hook(hook_fn)
        model.load_state_dict(model.state_dict(), strict=True)

    # fails swapping as LSTM installs weak references on the parameters
    @swap([False])
    @skipIfTorchDynamo("TorchDynamo fails here for unknown reasons")
    def test_load_state_dict_ref_cycle(self):
        # load_state_dict shouldn't cause a reference cycle involving Tensors
        import gc

        m = torch.nn.LSTM(16, 16, bidirectional=True)

        gc.collect()
        m.load_state_dict(deepcopy(m).state_dict())
        refcycles = gc.collect()

        self.assertEqual(refcycles, 0)

    @swap([True, False])
    def test_load_state_dict_custom(self):
        class CustomState(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = torch.nn.Parameter(torch.ones(1))
                self.sub = torch.nn.Linear(5, 5)

            def _save_to_state_dict(self, destination, prefix, keep_vars):
                destination[prefix + "serialized"] = self.param.data + 1

            def _load_from_state_dict(
                self,
                state_dict,
                prefix,
                local_metadata,
                strict,
                missing_keys,
                unexpected_keys,
                error_msgs,
            ):
                # skip some of the error handling
                self.param.data.copy_(state_dict[prefix + "serialized"] - 1)

        # use sequential to verify nesting
        m = nn.Sequential(CustomState())
        with torch.no_grad():
            m[0].param[0] = 10
            m[0].sub.weight[0, 0] = 555
        state_dict = m.state_dict()
        self.assertEqual(state_dict["0.serialized"].item(), 11)
        self.assertIn("0.sub.weight", state_dict)
        self.assertNotIn("0.param", state_dict)
        del m
        mm = nn.Sequential(CustomState())
        self.assertEqual(mm[0].param[0].item(), 1)
        mm.load_state_dict(state_dict)
        self.assertEqual(mm[0].param[0].item(), 10)
        self.assertEqual(mm[0].sub.weight[0, 0].item(), 555)

    @swap([True, False])
    @parametrize("keep_vars", [True, False])
    def test_load_state_dict_assign_meta(self, keep_vars):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)
                self.x = nn.Parameter(torch.rand(5), requires_grad=False)

            def forward(self, input):
                return self.x + self.bn(self.fc1(input))

        swap = torch.__future__.get_swap_module_params_on_conversion()
        net = MyModule()
        state_dict = net.state_dict(keep_vars=keep_vars)
        for v in state_dict.values():
            v.requires_grad_(False)

        with torch.device("meta"):
            net_meta = MyModule()

        net_meta_state_dict_old = net_meta.state_dict(keep_vars=True)
        net_meta.load_state_dict(state_dict, assign=True)

        # Make sure parameters and persistent buffers were assigned
        net_meta_state_dict = net_meta.state_dict(keep_vars=True)
        for key in state_dict.keys():
            if key in net_meta._parameters:
                if keep_vars and not swap:
                    # state_dict[key] is an nn.Parameter
                    self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                else:
                    if swap:
                        self.assertTrue(
                            net_meta_state_dict[key] is net_meta_state_dict_old[key]
                        )
                    else:
                        # state_dict[key] is not an nn.Parameter so it will be detached when wrapping with a Parameter
                        self.assertTrue(
                            net_meta_state_dict[key] is not net_meta_state_dict_old[key]
                        )
                        self.assertEqual(
                            net_meta_state_dict_old[key].requires_grad,
                            net_meta_state_dict[key].requires_grad,
                        )
                self.assertEqual(
                    net_meta_state_dict_old[key].requires_grad,
                    net_meta_state_dict[key].requires_grad,
                )
                self.assertEqual(state_dict[key], net_meta_state_dict[key])
            elif (
                key in net_meta._buffers
                and key not in net_meta._non_persistent_buffers_set
            ):
                self.assertTrue(state_dict[key] is net_meta_state_dict[key])
                self.assertEqual(state_dict[key], net_meta_state_dict[key])

        # Make sure that ordering of parameters and buffers is preserved
        net_named_parameters = net.named_parameters()
        net_named_buffers = net.named_buffers()
        net_meta_named_parameters = net_meta.named_parameters()
        net_meta_named_buffers = net_meta.named_buffers()

        for (n1, _), (n2, _) in zip(net_named_parameters, net_meta_named_parameters):
            self.assertEqual(n1, n2)

        for (n1, _), (n2, _) in zip(net_named_buffers, net_meta_named_buffers):
            self.assertEqual(n1, n2)

        # Make sure outputs are the same
        t = torch.randn(4, 3)
        out_net = net(t)
        out_net_meta = net_meta(t.clone())

        self.assertEqual(out_net, out_net_meta)

    @swap([True, False])
    def test_load_state_dict_assign_with_optimizer(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        opt = torch.optim.Adam(net.parameters(), lr=1000)
        x = torch.randn(4, 3)
        num_iters = 3

        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

        opt_state_dict = deepcopy(opt.state_dict())
        net_state_dict = deepcopy(net.state_dict())

        with torch.device("meta"):
            net_meta = MyModule()

        net_meta.load_state_dict(net_state_dict, assign=True)
        # must create optimizer only after loading state_dict when assign=True
        opt2 = torch.optim.Adam(net_meta.parameters(), lr=1000)
        opt2.load_state_dict(opt_state_dict)

        y = x.clone()
        for i in range(num_iters):
            opt.zero_grad()
            out = net(x)
            out.sum().backward()
            opt.step()

            opt2.zero_grad()
            out2 = net_meta(y)
            out2.sum().backward()
            opt2.step()

        self.assertEqual(opt.state_dict(), opt2.state_dict())
        self.assertEqual(net.state_dict(), net_meta.state_dict())

    @swap([True, False])
    def test_load_state_dict_assign_shape_stride(self):
        # Assigned tensor is allowed to have different properties than initial
        # tensor except for shape
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = nn.Linear(3, 5)
                self.bn = nn.BatchNorm1d(5)

            def forward(self, input):
                return self.bn(self.fc1(input))

        net = MyModule()
        state_dict = net.state_dict()
        # loading should be ok if stride is different
        state_dict["fc1.weight"] = torch.randn(3, 5).transpose(0, 1)
        net2 = MyModule()
        net2.load_state_dict(state_dict, strict=False, assign=True)

        state_dict["fc1.weight"] = torch.randn(2, 4)
        with self.assertRaisesRegex(
            RuntimeError, "size mismatch for fc1.weight: copying a param with shape"
        ):
            net2.load_state_dict(state_dict, strict=False, assign=True)

    @swap([True, False])
    def test_load_state_dict_warn_assign(self):
        with torch.device("meta"):
            m = torch.nn.Linear(3, 5)
        state_dict = m.state_dict()
        state_dict["weight"] = torch.empty_like(state_dict["weight"], device="cpu")
        with self.assertWarnsRegex(
            UserWarning,
            "for weight: copying from a non-meta parameter in the checkpoint to a meta",
        ):
            m.load_state_dict(state_dict)

    @swap([True, False])
    def test_load_state_dict_with_unexpected_key(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.fc1 = torch.nn.Linear(5, 10)

        m = MyModule()

        # Unexpected key & strict = True
        with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
            state_dict = m.state_dict()
            state_dict["fc1.bad_suffix"] = torch.randn(5, 10)
            m.load_state_dict(state_dict)

        # Unexpected key & strict = False
        state_dict = m.load_state_dict(state_dict, strict=False)
        self.assertIn("fc1.bad_suffix", state_dict.unexpected_keys)

        # Unexpected key whose prefix matches a valid key & strict = True
        with self.assertRaisesRegex(RuntimeError, "Unexpected key"):
            state_dict = m.state_dict()
            state_dict["fc1.weight.bad_suffix"] = torch.randn(5, 10)
            m.load_state_dict(state_dict)

        # Unexpected key whose prefix matches a valid key & strict = False
        state_dict = m.load_state_dict(state_dict, strict=False)
        self.assertIn("fc1.weight.bad_suffix", state_dict.unexpected_keys)


def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
    kwargs = {} if kwargs is None else kwargs

    def module_load(dest, src, assign=False):
        if isinstance(dest, cls):
            if assign:
                return src.detach()
            else:
                if type(src) is torch.Tensor:
                    return cls(src)
                elif type(src) is cls:
                    return src.detach()
                else:
                    if isinstance(src, MyWrapperLoadTensor):
                        return cls(src._data)
                    return cls(src)
        else:
            assert isinstance(
                src, cls
            ), f"Expected isinstance(src, {cls}) but got {type(src)}"
            assert (
                type(dest) == torch.Tensor
                or type(dest) == torch.nn.Parameter
                or issubclass(cls, type(dest))
            )
            if assign:
                return src.detach()
            else:
                if isinstance(src, MyWrapperLoadTensor):
                    if type(dest) not in {torch.Tensor, torch.nn.Parameter}:
                        return type(dest)(src._data)
                    else:
                        return src._data.detach()
                else:
                    return torch.Tensor(src)

    if func is torch.Tensor.module_load:
        return module_load(*args, **kwargs)
    else:
        with torch._C.DisableTorchFunctionSubclass():
            # detach must return instance of same subclass for nn.Parameter()
            if func == torch.Tensor.detach:
                ret = func(*args, **kwargs)
                if not isinstance(ret, cls):
                    return cls(ret)
                return ret
            return func(*args, **kwargs)


class MyLoadTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return load_torch_function_handler(cls, func, types, args, kwargs)


# We use MyLoadTensor2 to test tensor subclass, wrapper tensor subclass
# where neither inherits from each other
class MyLoadTensor2(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        return load_torch_function_handler(cls, func, types, args, kwargs)


class MyBrokenLoadTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.Tensor.module_load:
            # wrong as this doesn't detach!
            return args[1]
        else:
            with torch._C.DisableTorchFunctionSubclass():
                # detach must return instance of same subclass for nn.Parameter()
                if func == torch.Tensor.detach:
                    return cls(func(*args, **kwargs))
                return func(*args, **kwargs)


class MyWrapperLoadTensor(MyLoadTensor):
    @staticmethod
    def __new__(cls, data: torch.Tensor):
        t = torch.Tensor._make_wrapper_subclass(
            cls,
            data.size(),
            dtype=data.dtype,
            layout=data.layout,
            device=data.device,
            requires_grad=data.requires_grad,
            strides=data.stride(),
            storage_offset=data.storage_offset(),
        )
        return t

    def __init__(self, data: torch.Tensor):
        self._data = data

    def __repr__(self):
        return f"MyWrapperLoadTensor({self._data.__repr__()})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(t):
            return t._data if isinstance(t, MyWrapperLoadTensor) else t

        def wrap(t):
            return MyWrapperLoadTensor(t) if isinstance(t, torch.Tensor) else t

        kwargs = {} if kwargs is None else kwargs
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)


class TestLoadStateDictSwap(TestCase):
    @skipIfCrossRef
    @skipIfTorchDynamo("Can't swap with dynamo as dynamo installs weakrefs")
    @swap([True])
    @parametrize("assign", [True, False])
    def test_swap_subclass(self, assign):
        def _create_model(subclass=None):
            m = torch.nn.Linear(2, 3, bias=False)
            m.buf = torch.nn.Buffer(torch.randn(2, 3))
            if subclass is not None:
                m.weight = torch.nn.Parameter(subclass(m.weight))
                m.buf = subclass(m.buf)
            return m

        def _test(m_subclass=None, sd_subclass=None):
            m = _create_model(m_subclass)
            sd = _create_model(sd_subclass).state_dict()
            m.load_state_dict(sd, assign=assign)
            self.assertEqual(m.weight, sd["weight"])
            self.assertEqual(m.buf, sd["buf"])
            self.assertTrue(isinstance(m.weight, torch.nn.Parameter))
            self.assertTrue(not isinstance(m.buf, torch.nn.Parameter))

            weight_type, buf_type = (torch.nn.Parameter, torch.Tensor)
            if assign:
                if sd_subclass is not None:
                    weight_type, buf_type = (sd_subclass, sd_subclass)
            else:
                if m_subclass is not None:
                    weight_type, buf_type = (m_subclass, m_subclass)

            self.assertTrue(type(m.weight) is weight_type)
            self.assertTrue(type(m.buf) is buf_type)

        # (MyLoadTensor, MyWrapperLoadTensor) tests the behavior of (superclass, subclass)
        subclasses = [None, MyLoadTensor, MyLoadTensor2, MyWrapperLoadTensor]
        for m_s, sd_s in product(subclasses, subclasses):
            _test(m_s, sd_s)

        # MyBrokenLoadTensor should error since its module_load doesn't call .detach()
        with self.assertRaisesRegex(
            RuntimeError, re.escape("Error(s) in loading state_dict for Linear:")
        ):
            _test(None, MyBrokenLoadTensor)


instantiate_parametrized_tests(TestLoadStateDict)
instantiate_parametrized_tests(TestLoadStateDictSwap)

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()
