# Copyright 2019 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Logging helper module."""


# System modules
import os.path
import sys
import traceback
from typing import Union


# TODO(yunlian@google.com): Use GetRoot from misc
def GetRoot(scr_name):
    """Break up pathname into (dir+name)."""
    abs_path = os.path.abspath(scr_name)
    return (os.path.dirname(abs_path), os.path.basename(abs_path))


class Logger:
    """Logging helper class."""

    MAX_LOG_FILES = 10

    def __init__(self, rootdir, basefilename, print_console, subdir="logs"):
        logdir = os.path.join(rootdir, subdir)
        basename = os.path.join(logdir, basefilename)

        try:
            os.makedirs(logdir)
        except OSError:
            pass
            # print("Warning: Logs directory '%s' already exists." % logdir)

        self.print_console = print_console

        self._CreateLogFileHandles(basename)

        self._WriteTo(self.cmdfd, " ".join(sys.argv), True)

    def _AddSuffix(self, basename, suffix):
        return "%s%s" % (basename, suffix)

    def _FindSuffix(self, basename):
        timestamps = []
        found_suffix = None
        for i in range(self.MAX_LOG_FILES):
            suffix = str(i)
            suffixed_basename = self._AddSuffix(basename, suffix)
            cmd_file = "%s.cmd" % suffixed_basename
            if not os.path.exists(cmd_file):
                found_suffix = suffix
                break
            timestamps.append(os.stat(cmd_file).st_mtime)

        if found_suffix:
            return found_suffix

        # Try to pick the oldest file with the suffix and return that one.
        suffix = str(timestamps.index(min(timestamps)))
        # print ("Warning: Overwriting log file: %s" %
        #       self._AddSuffix(basename, suffix))
        return suffix

    def _CreateLogFileHandle(self, name):
        fd = None
        try:
            fd = open(name, "w", encoding="utf-8")
        except IOError:
            print("Warning: could not open %s for writing." % name)
        return fd

    def _CreateLogFileHandles(self, basename):
        suffix = self._FindSuffix(basename)
        suffixed_basename = self._AddSuffix(basename, suffix)

        self.cmdfd = self._CreateLogFileHandle("%s.cmd" % suffixed_basename)
        self.stdout = self._CreateLogFileHandle("%s.out" % suffixed_basename)
        self.stderr = self._CreateLogFileHandle("%s.err" % suffixed_basename)

        self._CreateLogFileSymlinks(basename, suffixed_basename)

    # Symlink unsuffixed basename to currently suffixed one.
    def _CreateLogFileSymlinks(self, basename, suffixed_basename):
        try:
            for extension in ["cmd", "out", "err"]:
                src_file = "%s.%s" % (
                    os.path.basename(suffixed_basename),
                    extension,
                )
                dest_file = "%s.%s" % (basename, extension)
                if os.path.exists(dest_file):
                    os.remove(dest_file)
                os.symlink(src_file, dest_file)
        except Exception as ex:
            print("Exception while creating symlinks: %s" % str(ex))

    def _WriteTo(self, fd, msg, flush):
        if fd:
            fd.write(msg)
            if flush:
                fd.flush()

    def LogStartDots(self, print_to_console=True):
        term_fd = self._GetStdout(print_to_console)
        if term_fd:
            term_fd.flush()
            term_fd.write(". ")
            term_fd.flush()

    def LogAppendDot(self, print_to_console=True):
        term_fd = self._GetStdout(print_to_console)
        if term_fd:
            term_fd.write(". ")
            term_fd.flush()

    def LogEndDots(self, print_to_console=True):
        term_fd = self._GetStdout(print_to_console)
        if term_fd:
            term_fd.write("\n")
            term_fd.flush()

    def LogMsg(self, file_fd, term_fd, msg, flush=True):
        if file_fd:
            self._WriteTo(file_fd, msg, flush)
        if self.print_console:
            self._WriteTo(term_fd, msg, flush)

    def _GetStdout(self, print_to_console):
        if print_to_console:
            return sys.stdout
        return None

    def _GetStderr(self, print_to_console):
        if print_to_console:
            return sys.stderr
        return None

    def LogCmdToFileOnly(self, cmd, machine="", user=None):
        if not self.cmdfd:
            return

        host = ("%s@%s" % (user, machine)) if user else machine
        flush = True
        cmd_string = "CMD (%s): %s\n" % (host, cmd)
        self._WriteTo(self.cmdfd, cmd_string, flush)

    def LogCmd(self, cmd, machine="", user=None, print_to_console=True):
        if user:
            host = "%s@%s" % (user, machine)
        else:
            host = machine

        self.LogMsg(
            self.cmdfd,
            self._GetStdout(print_to_console),
            "CMD (%s): %s\n" % (host, cmd),
        )

    def LogFatal(self, msg, print_to_console=True):
        self.LogMsg(
            self.stderr, self._GetStderr(print_to_console), "FATAL: %s\n" % msg
        )
        self.LogMsg(
            self.stderr,
            self._GetStderr(print_to_console),
            "\n".join(traceback.format_stack()),
        )
        sys.exit(1)

    def LogError(self, msg, print_to_console=True):
        self.LogMsg(
            self.stderr, self._GetStderr(print_to_console), "ERROR: %s\n" % msg
        )

    def LogWarning(self, msg, print_to_console=True):
        self.LogMsg(
            self.stderr,
            self._GetStderr(print_to_console),
            "WARNING: %s\n" % msg,
        )

    def LogOutput(self, msg, print_to_console=True):
        self.LogMsg(
            self.stdout, self._GetStdout(print_to_console), "OUTPUT: %s\n" % msg
        )

    def LogFatalIf(self, condition, msg):
        if condition:
            self.LogFatal(msg)

    def LogErrorIf(self, condition, msg):
        if condition:
            self.LogError(msg)

    def LogWarningIf(self, condition, msg):
        if condition:
            self.LogWarning(msg)

    def LogCommandOutput(self, msg, print_to_console=True):
        self.LogMsg(
            self.stdout, self._GetStdout(print_to_console), msg, flush=False
        )

    def LogCommandError(self, msg, print_to_console=True):
        self.LogMsg(
            self.stderr, self._GetStderr(print_to_console), msg, flush=False
        )

    def Flush(self):
        self.cmdfd.flush()
        self.stdout.flush()
        self.stderr.flush()


class MockLogger:
    """Logging helper class."""

    MAX_LOG_FILES = 10

    def __init__(self, *_args, **_kwargs):
        self.stdout = sys.stdout
        self.stderr = sys.stderr

    def _AddSuffix(self, basename, suffix):
        return "%s%s" % (basename, suffix)

    def _FindSuffix(self, basename):
        timestamps = []
        found_suffix = None
        for i in range(self.MAX_LOG_FILES):
            suffix = str(i)
            suffixed_basename = self._AddSuffix(basename, suffix)
            cmd_file = "%s.cmd" % suffixed_basename
            if not os.path.exists(cmd_file):
                found_suffix = suffix
                break
            timestamps.append(os.stat(cmd_file).st_mtime)

        if found_suffix:
            return found_suffix

        # Try to pick the oldest file with the suffix and return that one.
        suffix = str(timestamps.index(min(timestamps)))
        # print ("Warning: Overwriting log file: %s" %
        #       self._AddSuffix(basename, suffix))
        return suffix

    def _CreateLogFileHandle(self, name):
        print("MockLogger: creating open file handle for %s (writing)" % name)

    def _CreateLogFileHandles(self, basename):
        suffix = self._FindSuffix(basename)
        suffixed_basename = self._AddSuffix(basename, suffix)

        print("MockLogger: opening file %s.cmd" % suffixed_basename)
        print("MockLogger: opening file %s.out" % suffixed_basename)
        print("MockLogger: opening file %s.err" % suffixed_basename)

        self._CreateLogFileSymlinks(basename, suffixed_basename)

    # Symlink unsuffixed basename to currently suffixed one.
    def _CreateLogFileSymlinks(self, basename, suffixed_basename):
        for extension in ["cmd", "out", "err"]:
            src_file = "%s.%s" % (
                os.path.basename(suffixed_basename),
                extension,
            )
            dest_file = "%s.%s" % (basename, extension)
            print(
                "MockLogger: Calling os.symlink(%s, %s)" % (src_file, dest_file)
            )

    def _WriteTo(self, _fd, msg, _flush):
        print("MockLogger: %s" % msg)

    def LogStartDots(self, _print_to_console=True):
        print(". ")

    def LogAppendDot(self, _print_to_console=True):
        print(". ")

    def LogEndDots(self, _print_to_console=True):
        print("\n")

    def LogMsg(self, _file_fd, _term_fd, msg, **_kwargs):
        print("MockLogger: %s" % msg)

    def _GetStdout(self, _print_to_console):
        return None

    def _GetStderr(self, _print_to_console):
        return None

    def LogCmdToFileOnly(self, *_args, **_kwargs):
        return

    # def LogCmdToFileOnly(self, cmd, machine='', user=None):
    #   host = ('%s@%s' % (user, machine)) if user else machine
    #   cmd_string = 'CMD (%s): %s\n' % (host, cmd)
    #   print('MockLogger: Writing to file ONLY: %s' % cmd_string)

    def LogCmd(self, cmd, machine="", user=None, print_to_console=True):
        if user:
            host = "%s@%s" % (user, machine)
        else:
            host = machine

        self.LogMsg(
            0, self._GetStdout(print_to_console), "CMD (%s): %s\n" % (host, cmd)
        )

    def LogFatal(self, msg, print_to_console=True):
        self.LogMsg(0, self._GetStderr(print_to_console), "FATAL: %s\n" % msg)
        self.LogMsg(
            0,
            self._GetStderr(print_to_console),
            "\n".join(traceback.format_stack()),
        )
        print("MockLogger: Calling sysexit(1)")

    def LogError(self, msg, print_to_console=True):
        self.LogMsg(0, self._GetStderr(print_to_console), "ERROR: %s\n" % msg)

    def LogWarning(self, msg, print_to_console=True):
        self.LogMsg(0, self._GetStderr(print_to_console), "WARNING: %s\n" % msg)

    def LogOutput(self, msg, print_to_console=True):
        self.LogMsg(0, self._GetStdout(print_to_console), "OUTPUT: %s\n" % msg)

    def LogFatalIf(self, condition, msg):
        if condition:
            self.LogFatal(msg)

    def LogErrorIf(self, condition, msg):
        if condition:
            self.LogError(msg)

    def LogWarningIf(self, condition, msg):
        if condition:
            self.LogWarning(msg)

    def LogCommandOutput(self, msg, print_to_console=True):
        self.LogMsg(
            self.stdout, self._GetStdout(print_to_console), msg, flush=False
        )

    def LogCommandError(self, msg, print_to_console=True):
        self.LogMsg(
            self.stderr, self._GetStderr(print_to_console), msg, flush=False
        )

    def Flush(self):
        print("MockLogger: Flushing cmdfd, stdout, stderr")


main_logger = None


def InitLogger(script_name, log_dir, print_console=True, mock=False):
    """Initialize a global logger. To be called only once."""
    # pylint: disable=global-statement
    global main_logger
    if main_logger:
        return main_logger
    rootdir, basefilename = GetRoot(script_name)
    if not log_dir:
        log_dir = rootdir
    if not mock:
        main_logger = Logger(log_dir, basefilename, print_console)
    else:
        main_logger = MockLogger(log_dir, basefilename, print_console)
    return main_logger


def GetLogger(log_dir="", mock=False) -> Union[Logger, MockLogger]:
    return InitLogger(sys.argv[0], log_dir, mock=mock)


def HandleUncaughtExceptions(fun):
    """Catches all exceptions that would go outside decorated fun scope."""

    def _Interceptor(*args, **kwargs):
        try:
            return fun(*args, **kwargs)
        except Exception:
            GetLogger().LogFatal(
                "Uncaught exception:\n%s" % traceback.format_exc()
            )

    return _Interceptor
