# Copyright 2024, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes used to handle banners."""

from __future__ import annotations

from datetime import date
import json
import logging
from pathlib import Path
from typing import Any, Callable

from atest import atest_utils
from atest import constants


class BannerHistory:
  """Persisted record of when banners were last shown.

  The record is a JSON object stored in a `banner.json` file; the only key
  managed here is `last_banner_prompt_date`.
  """

  _LAST_BANNER_PROMPT_DATE = 'last_banner_prompt_date'

  @staticmethod
  def create(config_dir: Path) -> BannerHistory:
    """Load the banner history stored under the given directory.

    The directory (including parents) and an empty `banner.json` are created
    when they do not exist yet.

    Args:
      config_dir: Directory that holds (or will hold) `banner.json`.

    Returns:
      A BannerHistory backed by `config_dir/banner.json`. The history is empty
      when the file is new, empty, or does not contain a JSON object.
    """
    config_dir.mkdir(parents=True, exist_ok=True)
    history_file = config_dir.joinpath('banner.json')

    if not history_file.exists():
      history_file.touch()
      return BannerHistory(history_file, {})

    return BannerHistory(history_file, BannerHistory._load(history_file))

  @staticmethod
  def _load(history_file: Path) -> dict:
    """Read the history dict from an existing file, falling back to empty."""
    content = history_file.read_text()

    # create() leaves a zero-length file behind until the first date is
    # saved, so an empty file is a normal state rather than a format error.
    if not content.strip():
      return {}

    try:
      history = json.loads(content)
    except json.JSONDecodeError as e:
      atest_utils.print_and_log_error(
          'Banner history json file is in a bad format: %s', e
      )
      return {}

    # Valid JSON that is not an object (list, string, null, ...) cannot be
    # used as the history mapping; the accessors below rely on dict methods.
    if not isinstance(history, dict):
      atest_utils.print_and_log_error(
          'Banner history json file is in a bad format: %s',
          f'expected a JSON object but found {type(history).__name__}',
      )
      return {}

    return history

  def __init__(self, history_file: Path, history: dict):
    """Initialize the history.

    Args:
      history_file: File that updates are written to.
      history: Already-loaded history content.
    """
    self._history_file = history_file
    self._history = history

  def get_last_banner_prompt_date(self) -> str:
    """Get the last date when banners were prompted.

    Returns:
      The stored date, or an empty string when none has been recorded.
    """
    return self._history.get(BannerHistory._LAST_BANNER_PROMPT_DATE, '')

  def set_last_banner_prompt_date(self, date: str):
    """Record the last date when banners were prompted and save it to disk.

    Args:
      date: The date string to store.
    """
    self._history[BannerHistory._LAST_BANNER_PROMPT_DATE] = date
    # The whole history is rewritten on every update.
    self._history_file.write_text(json.dumps(self._history))


class BannerPrinter:
  """Collects banner messages and prints them when the stored date differs."""

  @staticmethod
  def create() -> BannerPrinter:
    """Build a printer that uses the folder from atest_utils.get_config_folder."""
    config_folder = atest_utils.get_config_folder()
    return BannerPrinter(config_folder)

  def __init__(self, config_dir: Path):
    """Initialize with no registered messages.

    Args:
      config_dir: Directory handed to BannerHistory.create when printing.
    """
    self._config_dir = config_dir
    self._messages = []

  def register(self, message: str):
    """Queue a banner message for a later call to print."""
    self._messages.append(message)

  def print(self, print_func: Callable = None, date_supplier: Callable = None):
    """Print all registered banners unless they were already shown today.

    Nothing happens when no message is registered. Otherwise the date from
    date_supplier is compared against the last prompt date in the banner
    history; when they differ, every message is printed in registration order
    and the history is updated to the new date.

    Args:
      print_func: Called once per message. Defaults to a magenta
        atest_utils.colorful_print.
      date_supplier: Returns the current date as a string. Defaults to
        str(date.today()).
    """
    if not self._messages:
      return

    emit = print_func or (
        lambda m: atest_utils.colorful_print(m, constants.MAGENTA)
    )
    supply_date = date_supplier or (lambda: str(date.today()))

    current_date = supply_date()
    banner_history = BannerHistory.create(self._config_dir)
    if banner_history.get_last_banner_prompt_date() == current_date:
      # Banners were already shown for this date.
      return

    for banner in self._messages:
      emit(banner)

    banner_history.set_last_banner_prompt_date(current_date)
