# Owner(s): ["oncall: r2p"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import subprocess
from base64 import b64encode
from typing import cast, ClassVar
from unittest import TestCase

from etcd import EtcdKeyNotFound  # type: ignore[import]
from rendezvous_backend_test import RendezvousBackendTestMixin

from torch.distributed.elastic.rendezvous import (
    RendezvousConnectionError,
    RendezvousParameters,
)
from torch.distributed.elastic.rendezvous.etcd_rendezvous_backend import (
    create_backend,
    EtcdRendezvousBackend,
)
from torch.distributed.elastic.rendezvous.etcd_server import EtcdServer
from torch.distributed.elastic.rendezvous.etcd_store import EtcdStore


class EtcdRendezvousBackendTest(TestCase, RendezvousBackendTestMixin):
    _server: ClassVar[EtcdServer]

    @classmethod
    def setUpClass(cls) -> None:
        cls._server = EtcdServer()
        cls._server.start(stderr=subprocess.DEVNULL)

    @classmethod
    def tearDownClass(cls) -> None:
        cls._server.stop()

    def setUp(self) -> None:
        self._client = self._server.get_client()

        # Make sure we have a clean slate.
        try:
            self._client.delete("/dummy_prefix", recursive=True, dir=True)
        except EtcdKeyNotFound:
            pass

        self._backend = EtcdRendezvousBackend(
            self._client, "dummy_run_id", "/dummy_prefix"
        )

    def _corrupt_state(self) -> None:
        self._client.write("/dummy_prefix/dummy_run_id", "non_base64")


class CreateBackendTest(TestCase):
    _server: ClassVar[EtcdServer]

    @classmethod
    def setUpClass(cls) -> None:
        cls._server = EtcdServer()
        cls._server.start(stderr=subprocess.DEVNULL)

    @classmethod
    def tearDownClass(cls) -> None:
        cls._server.stop()

    def setUp(self) -> None:
        self._params = RendezvousParameters(
            backend="dummy_backend",
            endpoint=self._server.get_endpoint(),
            run_id="dummy_run_id",
            min_nodes=1,
            max_nodes=1,
            protocol="hTTp",
            read_timeout="10",
        )

        self._expected_read_timeout = 10

    def test_create_backend_returns_backend(self) -> None:
        backend, store = create_backend(self._params)

        self.assertEqual(backend.name, "etcd-v2")

        self.assertIsInstance(store, EtcdStore)

        etcd_store = cast(EtcdStore, store)

        self.assertEqual(etcd_store.client.read_timeout, self._expected_read_timeout)  # type: ignore[attr-defined]

        client = self._server.get_client()

        backend.set_state(b"dummy_state")

        result = client.get("/torch/elastic/rendezvous/" + self._params.run_id)

        self.assertEqual(result.value, b64encode(b"dummy_state").decode())
        self.assertLessEqual(result.ttl, 7200)

        store.set("dummy_key", "dummy_value")

        result = client.get("/torch/elastic/store/" + b64encode(b"dummy_key").decode())

        self.assertEqual(result.value, b64encode(b"dummy_value").decode())

    def test_create_backend_returns_backend_if_protocol_is_not_specified(self) -> None:
        del self._params.config["protocol"]

        self.test_create_backend_returns_backend()

    def test_create_backend_returns_backend_if_read_timeout_is_not_specified(
        self,
    ) -> None:
        del self._params.config["read_timeout"]

        self._expected_read_timeout = 60

        self.test_create_backend_returns_backend()

    def test_create_backend_raises_error_if_etcd_is_unreachable(self) -> None:
        self._params.endpoint = "dummy:1234"

        with self.assertRaisesRegex(
            RendezvousConnectionError,
            r"^The connection to etcd has failed. See inner exception for details.$",
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_protocol_is_invalid(self) -> None:
        self._params.config["protocol"] = "dummy"

        with self.assertRaisesRegex(
            ValueError, r"^The protocol must be HTTP or HTTPS.$"
        ):
            create_backend(self._params)

    def test_create_backend_raises_error_if_read_timeout_is_invalid(self) -> None:
        for read_timeout in ["0", "-10"]:
            with self.subTest(read_timeout=read_timeout):
                self._params.config["read_timeout"] = read_timeout

                with self.assertRaisesRegex(
                    ValueError, r"^The read timeout must be a positive integer.$"
                ):
                    create_backend(self._params)
