# Owner(s): ["oncall: jit"]

import re

import torch
import torch._lazy.metrics as metrics
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, TestCase


# Initialize the lazy-tensor TS backend once, at import time, before any test
# in this module runs.
torch._lazy.ts_backend.init()

# Matches ", NodeType=" together with the remaining characters on that line.
# test_data_assign's text() helper removes these matches from the IR dump
# before comparing it against the expected text.
NODE_TYPE_PATTERN = re.compile(r", NodeType=[^\n]+")


class LazyFuncionalizationTest(TestCase):
    """Tests run against the "lazy" device: model init with a view, and `.data` assignment."""

    def test_lazy_init_with_view(self):
        # Builds the same small model on "cpu" and on "lazy" (with and without
        # resetting the weight's storage) and checks that all three runs
        # produce the same output.
        def f(device, reset_storage=False):
            """Build the model under ``device``, run it on ``torch.ones(4)`` and return the result.

            Args:
                device: Device string passed to ``torch.device``; the
                    lazy-only steps run when it equals ``"lazy"``.
                reset_storage: Only consulted on ``"lazy"``. When true,
                    ``torch._C._unsafe_reset_storage`` is called on the
                    linear layer's weight before the first ``mark_step``.

            Returns:
                The output of ``model(x)``, created under ``device``.
            """
            # Same seed for every call so each device gets identical weights.
            torch.manual_seed(2023)

            if device == "lazy":
                # Start from clean counters so the check below only sees
                # what this call produced.
                metrics.reset()

            class Model(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.fc1 = torch.nn.Linear(4, 2, bias=False)

                def forward(self, x):
                    # transpose() yields a view of the weight.
                    return x @ self.fc1.weight.transpose(0, 1)

            with torch.device(device):
                model = Model()

                if device == "lazy":
                    if reset_storage:
                        torch._C._unsafe_reset_storage(model.fc1.weight)

                    torch._lazy.mark_step()

                    sync_tensors = metrics.counter_value("SyncedTensorsWithIR")
                    # There is an extra tensor being unnecessarily synced if
                    # the functional storage is not reset.
                    expected_sync_tensors = 1 if reset_storage else 2
                    # assertEqual instead of a bare `assert`: the check is not
                    # stripped under `python -O` and a failure reports both
                    # the actual and the expected count.
                    self.assertEqual(sync_tensors, expected_sync_tensors)

                x = torch.ones(4)
                out = model(x)

                if device == "lazy":
                    torch._lazy.mark_step()

                return out

        cpu_out = f("cpu")
        lazy_out_1 = f("lazy", reset_storage=False)
        lazy_out_2 = f("lazy", reset_storage=True)

        self.assertEqual(cpu_out, lazy_out_1.to("cpu"))
        self.assertEqual(cpu_out, lazy_out_2.to("cpu"))

    def test_data_assign(self):
        # Checks the IR text of a lazy tensor before and after a tensor of a
        # different dtype is assigned to its `.data` attribute.
        def text(lazyt):
            """Return the IR dump of ``lazyt`` with NODE_TYPE_PATTERN matches removed."""
            raw = torch._C._lazy._get_tensors_text([lazyt])
            return NODE_TYPE_PATTERN.sub("", raw)

        origin = torch.rand(3, dtype=torch.float32)
        tensor = origin.to("lazy")

        self.assertExpectedInline(
            text(tensor),
            """\
IR {
  %0 = [Float[3]] lazy_tensors::device_data(), device=CPU0, ROOT=0
}
""",
        )

        # Modify the data-type of tensor, and assign it to 'data'.
        # This should update the inner tensor of FunctionalTensorWrapper,
        # changing the corresponding IR node.
        modified_tensor = tensor.to(torch.bfloat16)
        tensor.data = modified_tensor

        self.assertExpectedInline(
            text(tensor),
            """\
IR {
  %0 = [Float[3]] lazy_tensors::device_data(), device=CPU0
  %1 = [BFloat16[3]] aten::_to_copy(%0), dtype=BFloat16, layout=null, device=null, pin_memory=null, non_blocking=0, memory_format=null, ROOT=0
}
""",  # noqa: B950
        )


# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    run_tests()
