# Owner(s): ["module: unknown"]
import unittest
from dataclasses import dataclass
from typing import Any, Callable, cast, Tuple, Union

import torch
from torch import nn, optim
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.distributed._tools.runtime_estimator import RuntimeEstimator
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
from torch.testing._internal.distributed._tensor.common_dtensor import (
    ModelArgs,
    Transformer,
)


@dataclass
class ConvArgs:
    image_size: int
    num_classes: int


class SimpleCNN(nn.Module):
    def __init__(self, conv_args: ConvArgs):
        super().__init__()
        image_size = conv_args.image_size
        num_classes = conv_args.num_classes
        self.image_size = image_size
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3)
        self.fc1_size = self._calculate_fc1_size()
        self.fc1 = nn.Linear(self.fc1_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def _calculate_fc1_size(self):
        size = self.image_size
        size = (size - 5 + 1) // 2  # conv1 and pool
        size = (size - 5 + 1) // 2  # conv2 and pool
        size = size - 3 + 1  # conv3
        size = (size - 3 + 1) // 2  # conv4 and pool
        return 512 * size * size

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = nn.functional.relu(self.conv3(x))
        x = self.pool(nn.functional.relu(self.conv4(x)))
        x = x.view(-1, self.fc1_size)
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class TestRuntimeEstimator(TestCase):
    def _train_step(
        self,
        model: nn.Module,
        optimizer: optim.Optimizer,
        inp: torch.Tensor,
    ):
        out = model(inp)
        loss = out.sum()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    def _measure_actual_cuda_time(
        self,
        func: Callable,
        args: Tuple[Any, ...],
    ) -> float:
        warmup_iters, actual_iters = 2, 5
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        for _ in range(warmup_iters):
            func(*args)
        start_event.record()
        for _ in range(actual_iters):
            func(*args)
        end_event.record()
        torch.cuda.synchronize()
        measured_time = start_event.elapsed_time(end_event) / actual_iters
        return measured_time

    def _runtime_estimate(
        self,
        estimate_mode: str,
        func: Callable,
        args: Tuple[Any, ...],
    ) -> float:
        # Optimizer init step
        func(*args)
        runtime_estimator = RuntimeEstimator()
        with runtime_estimator(estimate_mode_type=estimate_mode):
            func(*args)
        return runtime_estimator.total_runtime

    def _init_model_and_args(
        self,
        model_type: str,
        model_args: Union[ConvArgs, ModelArgs],
        bsz: int,
    ) -> Tuple[nn.Module, optim.Optimizer, torch.Tensor]:
        dev = torch.cuda.current_device()
        if model_type == "Transformer":
            model_args = cast(ModelArgs, model_args)
            with torch.device(dev):
                model = Transformer(model_args)
            optimizer = optim.Adam(model.parameters(), lr=1e-2, foreach=True)
            inp = torch.randint(
                0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev
            )
        elif model_type == "CNN":
            model_args = cast(ConvArgs, model_args)
            with torch.device(dev):
                model = SimpleCNN(model_args)
            optimizer = optim.SGD(model.parameters(), lr=1e-2, foreach=True)
            inp = torch.randn(
                bsz, 3, model_args.image_size, model_args.image_size, device=dev
            )
        else:
            raise NotImplementedError("Only Transformer and CNN is supported")
        return (model, optimizer, inp)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_transformer_runtime(
        self,
    ):
        """Runs a basic GPT-2 model"""
        vocab_size = 8192
        bsz, seq_len = 8, 1024
        model_args = ModelArgs(
            n_layers=4,
            n_heads=12,
            vocab_size=vocab_size,
            max_seq_len=seq_len,
            dim=768,
            dropout_p=0.1,
        )

        args = self._init_model_and_args("Transformer", model_args, bsz)
        actual_runtime = self._measure_actual_cuda_time(self._train_step, args)
        with FakeTensorMode():
            fake_args = self._init_model_and_args("Transformer", model_args, bsz)
            benchmark_estimate = self._runtime_estimate(
                "operator-level-benchmark", self._train_step, fake_args
            )
            roofline_estimate = self._runtime_estimate(
                "operator-level-cost-model", self._train_step, fake_args
            )
        benchmark_accuracy = actual_runtime / benchmark_estimate
        roofline_accuracy = actual_runtime / roofline_estimate
        print(
            f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}"
            f"\n Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}"
        )
        self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2)
        self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.3)

    @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653")
    @unittest.skipIf(not TEST_CUDA, "CUDA not available")
    def test_conv_model_runtime(
        self,
    ):
        """Runs a simple CNN model"""
        num_classes = 100
        bsz, img_sz = 256, 128
        model_args = ConvArgs(img_sz, num_classes)
        args = self._init_model_and_args("CNN", model_args, bsz)
        actual_runtime = self._measure_actual_cuda_time(self._train_step, args)
        with FakeTensorMode():
            fake_args = self._init_model_and_args("CNN", model_args, bsz)
            benchmark_estimate = self._runtime_estimate(
                "operator-level-benchmark", self._train_step, fake_args
            )
            roofline_estimate = self._runtime_estimate(
                "operator-level-cost-model", self._train_step, fake_args
            )
        benchmark_accuracy = actual_runtime / benchmark_estimate
        roofline_accuracy = actual_runtime / roofline_estimate
        print(
            f"Actual: {actual_runtime} Benchmark Estimate: {benchmark_estimate} Accuracy: {benchmark_accuracy}\n"
            f"Actual: {actual_runtime} Roofline Estimatee: {roofline_estimate} Accuracy: {roofline_accuracy}"
        )
        self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2)
        self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.4)


if __name__ == "__main__":
    run_tests()
