# Owner(s): ["oncall: distributed"]

import json
import logging
import os
import re
import sys
import time
from functools import partial, wraps

import torch
import torch.distributed as dist
from torch.distributed.c10d_logger import _c10d_logger, _exception_logger, _time_logger


if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import MultiProcessTestCase, TEST_SKIPS
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
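
# Run with 2-4 ranks, capped by the number of visible GPUs; with_comms below
# skips tests when fewer GPUs are available than ranks.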

BACKEND = dist.Backend.NCCL
WORLD_SIZE = min(4, max(2, torch.cuda.device_count()))


def with_comms(func=None):
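    """Initialize the default process group before the test and tear it down
    afterwards, skipping NCCL runs when the host has fewer GPUs than ranks."""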
    if func is None:
        return partial(with_comms)

    @wraps(func)
    def wrapper(self, *args, **kwargs):
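        # Skip rather than fail on hosts without enough GPUs for NCCL.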
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self, *args, **kwargs)
        self.destroy_comms()

    return wrapper


class C10dErrorLoggerTest(MultiProcessTestCase):
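    """Tests for the structured records emitted by the c10d logging decorators."""
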
    def setUp(self):
        super().setUp()
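        # Expose the run configuration to the spawned worker processes.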
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = BACKEND
        self._spawn_processes()

    @property
    def device(self):
        return (
            torch.device(self.rank)
            if BACKEND == dist.Backend.NCCL
            else torch.device("cpu")
        )

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    def dist_init(self):
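        # Rendezvous over a shared file to avoid claiming a TCP port per test.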
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # Bind each rank to its own GPU so NCCL collectives use the right device.
        if BACKEND == dist.Backend.NCCL:
            torch.cuda.set_device(self.rank)

    def test_get_or_create_logger(self):
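        # The module-level logger should exist and carry exactly one handler,
        # a NullHandler, which keeps c10d logging silent by default.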
        self.assertIsNotNone(_c10d_logger)
        self.assertEqual(1, len(_c10d_logger.handlers))
        self.assertIsInstance(_c10d_logger.handlers[0], logging.NullHandler)
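        # A minimal sketch for local debugging (not exercised here): attaching
        # logging.StreamHandler() to _c10d_logger would surface the records
        # that the tests below extract via assertLogs.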

    @_exception_logger
    def _failed_broadcast_raise_exception(self):
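        # world_size + 1 can never be a valid source rank, so broadcast fails.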
        tensor = torch.arange(2, dtype=torch.int64)
        dist.broadcast(tensor, self.world_size + 1)

    @_exception_logger
    def _failed_broadcast_not_raise_exception(self):
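        # Same invalid broadcast, but the exception is swallowed so only the
        # decorator's log record is observable.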
        try:
            tensor = torch.arange(2, dtype=torch.int64)
            dist.broadcast(tensor, self.world_size + 1)
        except Exception:
            pass

    @with_comms
    def test_exception_logger(self) -> None:
        with self.assertRaises(Exception):
            self._failed_broadcast_raise_exception()

        with self.assertLogs(_c10d_logger, level="DEBUG") as captured:
            self._failed_broadcast_not_raise_exception()
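            # _exception_logger records a dict repr; coerce the single quotes
            # to double quotes so json.loads can parse it.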
            error_msg_dict = json.loads(
                re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
            )

            self.assertEqual(len(error_msg_dict), 10)

            self.assertIn("pg_name", error_msg_dict.keys())
            self.assertEqual("None", error_msg_dict["pg_name"])

            self.assertIn("func_name", error_msg_dict.keys())
            self.assertEqual("broadcast", error_msg_dict["func_name"])

            self.assertIn("args", error_msg_dict.keys())

            self.assertIn("backend", error_msg_dict.keys())
            self.assertEqual("nccl", error_msg_dict["backend"])

            self.assertIn("nccl_version", error_msg_dict.keys())
            nccl_ver = torch.cuda.nccl.version()
            self.assertEqual(
                ".".join(str(v) for v in nccl_ver), error_msg_dict["nccl_version"]
            )

            # group_size equals world_size because the default (WORLD) process
            # group spans every rank.
            self.assertIn("group_size", error_msg_dict.keys())
            self.assertEqual(str(self.world_size), error_msg_dict["group_size"])

            self.assertIn("world_size", error_msg_dict.keys())
            self.assertEqual(str(self.world_size), error_msg_dict["world_size"])

            self.assertIn("global_rank", error_msg_dict.keys())
            self.assertIn(str(dist.get_rank()), error_msg_dict["global_rank"])

            # local_rank equals global_rank because all ranks run on one node.
            self.assertIn("local_rank", error_msg_dict.keys())
            self.assertIn(str(dist.get_rank()), error_msg_dict["local_rank"])

    @_time_logger
    def _dummy_sleep(self):
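        # Sleep a fixed 5 seconds so time_spent truncates to a known value.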
        time.sleep(5)

    @with_comms
    def test_time_logger(self) -> None:
        with self.assertLogs(_c10d_logger, level="DEBUG") as captured:
            self._dummy_sleep()
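            # Same dict-repr-to-JSON coercion as in test_exception_logger.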
            msg_dict = json.loads(
                re.search("({.+})", captured.output[0]).group(0).replace("'", '"')
            )
            self.assertEqual(len(msg_dict), 10)

            self.assertIn("pg_name", msg_dict.keys())
            self.assertEqual("None", msg_dict["pg_name"])

            self.assertIn("func_name", msg_dict.keys())
            self.assertEqual("_dummy_sleep", msg_dict["func_name"])

            self.assertIn("args", msg_dict.keys())

            self.assertIn("backend", msg_dict.keys())
            self.assertEqual("nccl", msg_dict["backend"])

            self.assertIn("nccl_version", msg_dict.keys())
            nccl_ver = torch.cuda.nccl.version()
            self.assertEqual(
                ".".join(str(v) for v in nccl_ver), msg_dict["nccl_version"]
            )

            # group_size equals world_size because the default (WORLD) process
            # group spans every rank.
            self.assertIn("group_size", msg_dict.keys())
            self.assertEqual(str(self.world_size), msg_dict["group_size"])

            self.assertIn("world_size", msg_dict.keys())
            self.assertEqual(str(self.world_size), msg_dict["world_size"])

            self.assertIn("global_rank", msg_dict.keys())
            self.assertIn(str(dist.get_rank()), msg_dict["global_rank"])

            # local_rank equals global_rank because all ranks run on one node.
            self.assertIn("local_rank", msg_dict.keys())
            self.assertIn(str(dist.get_rank()), msg_dict["local_rank"])

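            # time_spent is reported in nanoseconds; truncating division
            # recovers the whole seconds slept.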
            self.assertIn("time_spent", msg_dict.keys())
            time_ns = re.findall(r"\d+", msg_dict["time_spent"])[0]
            self.assertEqual(5, int(time_ns) // 10**9)


if __name__ == "__main__":
    run_tests()
