#!/usr/bin/env python3
# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
from dataclasses import asdict
from unittest.mock import patch

from torch.distributed.elastic.events import (
    _get_or_create_logger,
    construct_and_record_rdzv_event,
    Event,
    EventSource,
    NodeState,
    RdzvEvent,
)
from torch.testing._internal.common_utils import run_tests, TestCase


class EventLibTest(TestCase):
    def assert_event(self, actual_event, expected_event):
        self.assertEqual(actual_event.name, expected_event.name)
        self.assertEqual(actual_event.source, expected_event.source)
        self.assertEqual(actual_event.timestamp, expected_event.timestamp)
        self.assertDictEqual(actual_event.metadata, expected_event.metadata)

    @patch("torch.distributed.elastic.events.get_logging_handler")
    def test_get_or_create_logger(self, logging_handler_mock):
        logging_handler_mock.return_value = logging.NullHandler()
        logger = _get_or_create_logger("test_destination")
        self.assertIsNotNone(logger)
        self.assertEqual(1, len(logger.handlers))
        self.assertIsInstance(logger.handlers[0], logging.NullHandler)

    def test_event_created(self):
        event = Event(
            name="test_event",
            source=EventSource.AGENT,
            metadata={"key1": "value1", "key2": 2},
        )
        self.assertEqual("test_event", event.name)
        self.assertEqual(EventSource.AGENT, event.source)
        self.assertDictEqual({"key1": "value1", "key2": 2}, event.metadata)

    def test_event_deser(self):
        event = Event(
            name="test_event",
            source=EventSource.AGENT,
            metadata={"key1": "value1", "key2": 2, "key3": 1.0},
        )
        json_event = event.serialize()
        deser_event = Event.deserialize(json_event)
        self.assert_event(event, deser_event)


class RdzvEventLibTest(TestCase):
    @patch("torch.distributed.elastic.events.record_rdzv_event")
    @patch("torch.distributed.elastic.events.get_logging_handler")
    def test_construct_and_record_rdzv_event(self, get_mock, record_mock):
        get_mock.return_value = logging.StreamHandler()
        construct_and_record_rdzv_event(
            run_id="test_run_id",
            message="test_message",
            node_state=NodeState.RUNNING,
        )
        record_mock.assert_called_once()

    @patch("torch.distributed.elastic.events.record_rdzv_event")
    @patch("torch.distributed.elastic.events.get_logging_handler")
    def test_construct_and_record_rdzv_event_does_not_run_if_invalid_dest(
        self, get_mock, record_mock
    ):
        get_mock.return_value = logging.NullHandler()
        construct_and_record_rdzv_event(
            run_id="test_run_id",
            message="test_message",
            node_state=NodeState.RUNNING,
        )
        record_mock.assert_not_called()

    def assert_rdzv_event(self, actual_event: RdzvEvent, expected_event: RdzvEvent):
        self.assertEqual(actual_event.name, expected_event.name)
        self.assertEqual(actual_event.run_id, expected_event.run_id)
        self.assertEqual(actual_event.message, expected_event.message)
        self.assertEqual(actual_event.hostname, expected_event.hostname)
        self.assertEqual(actual_event.pid, expected_event.pid)
        self.assertEqual(actual_event.node_state, expected_event.node_state)
        self.assertEqual(actual_event.master_endpoint, expected_event.master_endpoint)
        self.assertEqual(actual_event.rank, expected_event.rank)
        self.assertEqual(actual_event.local_id, expected_event.local_id)
        self.assertEqual(actual_event.error_trace, expected_event.error_trace)

    def get_test_rdzv_event(self) -> RdzvEvent:
        return RdzvEvent(
            name="test_name",
            run_id="test_run_id",
            message="test_message",
            hostname="test_hostname",
            pid=1,
            node_state=NodeState.RUNNING,
            master_endpoint="test_master_endpoint",
            rank=3,
            local_id=4,
            error_trace="test_error_trace",
        )

    def test_rdzv_event_created(self):
        event = self.get_test_rdzv_event()
        self.assertEqual(event.name, "test_name")
        self.assertEqual(event.run_id, "test_run_id")
        self.assertEqual(event.message, "test_message")
        self.assertEqual(event.hostname, "test_hostname")
        self.assertEqual(event.pid, 1)
        self.assertEqual(event.node_state, NodeState.RUNNING)
        self.assertEqual(event.master_endpoint, "test_master_endpoint")
        self.assertEqual(event.rank, 3)
        self.assertEqual(event.local_id, 4)
        self.assertEqual(event.error_trace, "test_error_trace")

    def test_rdzv_event_deserialize(self):
        event = self.get_test_rdzv_event()
        json_event = event.serialize()
        deserialized_event = RdzvEvent.deserialize(json_event)
        self.assert_rdzv_event(event, deserialized_event)
        self.assert_rdzv_event(event, RdzvEvent.deserialize(event))

    def test_rdzv_event_str(self):
        event = self.get_test_rdzv_event()
        self.assertEqual(str(event), json.dumps(asdict(event)))


if __name__ == "__main__":
    # Script entry point: hand off to PyTorch's shared test runner
    # (run_tests from torch.testing._internal.common_utils).
    run_tests()
