#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 The ChromiumOS Authors
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Module of binary search for perforce."""

import argparse
import math
import os
import re
import sys
import tempfile

from cros_utils import command_executer
from cros_utils import logger


# Module-wide flag passed as `print_to_console` to most logger calls below.
verbose = True


def _GetP4ClientSpec(client_name, p4_paths):
    """Build the ` -a <view>` arguments appended to the `g4 client` command.

    Args:
      client_name: Client name; used as the `//<client_name>/` prefix that is
        inserted between the parts of a multi-part view entry.
      p4_paths: Iterable of view entries. An entry is either a string or a
        sequence of strings (for example a `(depot_path, client_path)`
        tuple). A string without a space is emitted unchanged; a string with
        spaces is split on whitespace into its parts.

    Returns:
      A string holding one ` -a ...` argument per entry ("" for no entries).
    """
    client_prefix = " //" + client_name + "/"
    view_args = []
    for p4_path in p4_paths:
        if isinstance(p4_path, str):
            if " " not in p4_path:
                view_args.append(" -a %s" % p4_path)
                continue
            # str.join() over a plain string would iterate over its
            # characters, so split the entry into its parts first.
            parts = p4_path.split()
        else:
            parts = p4_path
        view_args.append(' -a "' + client_prefix.join(parts) + '"')

    return "".join(view_args)


def GetP4Command(client_name, p4_port, p4_paths, checkoutdir, p4_snapshot=""):
    """Return the shell command line that populates `checkoutdir`.

    When `p4_snapshot` is non-empty the command copies from that snapshot:
    for every entry of `p4_paths` whose element at index 1 ends with "...",
    the parent directory is created under `checkoutdir` and the path is
    rsync'ed into it. Otherwise the command prepares a `.p4config` inside
    `checkoutdir`, creates the g4 client for `p4_paths` and syncs it.

    Args:
      client_name: Value written as P4CLIENT and used for the client views.
      p4_port: Value written as P4PORT.
      p4_paths: View entries; the snapshot mode reads `entry[1]`.
      checkoutdir: Directory the sources are placed in.
      p4_snapshot: Optional snapshot directory to copy from.

    Returns:
      The command as a single string.
    """
    if p4_snapshot:
        pieces = ["mkdir -p " + checkoutdir]
        for entry in p4_paths:
            local_path = entry[1]
            if not local_path.endswith("..."):
                continue
            local_path = local_path.replace("/...", "")
            target_dir = checkoutdir + "/" + os.path.dirname(local_path)
            pieces.append("; mkdir -p " + target_dir)
            pieces.append(
                "&& rsync -lr "
                + p4_snapshot
                + "/"
                + local_path
                + " "
                + target_dir
            )
        return "".join(pieces)

    # The individual steps are chained with " && "; the leading and trailing
    # blanks inside some steps are part of the emitted command text.
    steps = [
        " export P4CONFIG=.p4config",
        "mkdir -p " + checkoutdir,
        "cd " + checkoutdir,
        "cp ${HOME}/.p4config .",
        "chmod u+w .p4config",
        'echo "P4PORT=' + p4_port + '" >> .p4config',
        'echo "P4CLIENT=' + client_name + '" >> .p4config',
        "g4 client " + _GetP4ClientSpec(client_name, p4_paths),
        "g4 sync ",
        "cd -",
    ]
    return " && ".join(steps)


class BinarySearchPoint:
    """A revision together with the status recorded for it and a tag."""

    def __init__(self, revision, status, tag=None):
        # Plain record: the three values are stored as given.
        self.revision, self.status, self.tag = revision, status, tag


class BinarySearcherForPass:
    """Bisects over a pass/transformation limit number.

    The caller sets `total` before the first GetNext() call, then alternates
    GetNext() and SetStatus() until SetStatus() returns True.
    """

    def __init__(self, logger_to_set=None):
        self.logger = (
            logger.GetLogger() if logger_to_set is None else logger_to_set
        )
        self.total = 0
        self.lo = 0
        self.hi = 0
        self.current = 0

    def GetNext(self):
        """Compute, log and return the next limit number to try."""
        # On the first run the upper bound is still 0: take it from `total`.
        if self.hi == 0:
            self.hi = self.total
        self.current = (self.lo + self.hi) // 2
        self.logger.LogOutput(
            "Bisecting between: (%d, %d)" % (self.lo, self.hi),
            print_to_console=verbose,
        )
        self.logger.LogOutput(
            "Current limit number: %d" % self.current,
            print_to_console=verbose,
        )
        return self.current

    def SetStatus(self, status):
        """Narrow the search range according to the test script result.

        A status of 0 means the runtime error did not show up with the
        current limit, so the lower bound moves just above it. A status of 1
        means the error still shows up, so the current limit becomes the
        upper bound. A status of 125 leaves both bounds untouched.

        Returns:
          True when the search is over: either the current limit is 0, or the
          lower bound has reached the upper bound. False otherwise.
        """
        assert status in (0, 1, 125), status

        if self.current == 0:
            self.logger.LogOutput(
                "Runtime error occurs before first pass/transformation. "
                "Stop binary searching.",
                print_to_console=verbose,
            )
            return True

        if status == 0:
            self.logger.LogOutput(
                "Runtime error is not reproduced, increasing lower bound.",
                print_to_console=verbose,
            )
            self.lo = self.current + 1
        elif status == 1:
            self.logger.LogOutput(
                "Runtime error is reproduced, decreasing upper bound..",
                print_to_console=verbose,
            )
            self.hi = self.current

        return self.lo >= self.hi


class BinarySearcher(object):
    """Binary searcher over a sorted list of revisions.

    Test results are reported through SetStatus() with these codes:
      0:   the current revision is good; the lower bound moves above it.
      1:   the current revision is bad; it becomes the upper bound.
      125: the current revision could not be judged and is skipped.
    """

    def __init__(self, logger_to_set=None):
        self.sorted_list = []
        # index_log and status_log are parallel: one entry per SetStatus().
        self.index_log = []
        self.status_log = []
        self.skipped_indices = []
        self.current = 0
        # Maps list index -> BinarySearchPoint holding the latest result.
        self.points = {}
        self.lo = 0
        self.hi = 0
        if logger_to_set is not None:
            self.logger = logger_to_set
        else:
            self.logger = logger.GetLogger()

    def SetSortedList(self, sorted_list):
        """Start a new search over `sorted_list` (must be non-empty)."""
        assert sorted_list
        self.sorted_list = sorted_list
        # Drop every trace of a previous search; the logs index into the old
        # list and would otherwise be reported against the new one.
        self.index_log = []
        self.status_log = []
        self.skipped_indices = []
        self.current = 0
        self.hi = len(sorted_list) - 1
        self.lo = 0
        self.points = {
            i: BinarySearchPoint(revision, -1, "Not yet done.")
            for i, revision in enumerate(sorted_list)
        }

    def SetStatus(self, status, tag=None):
        """Record the result for the current index and narrow the range.

        Args:
          status: 0 (good), 1 (bad) or 125 (skip).
          tag: Optional value stored with the recorded point.

        Returns:
          True when the search is finished, False when another revision has
          to be tested.
        """
        message = "Revision: %s index: %d returned: %d" % (
            self.sorted_list[self.current],
            self.current,
            status,
        )
        self.logger.LogOutput(message, print_to_console=verbose)
        assert status in (0, 1, 125), status
        self.index_log.append(self.current)
        self.status_log.append(status)
        self.points[self.current] = BinarySearchPoint(
            self.sorted_list[self.current], status, tag
        )

        if status == 125:
            self.skipped_indices.append(self.current)
        else:
            if status == 0:
                self.lo = self.current + 1
            else:
                self.hi = self.current
            self.logger.LogOutput("lo: %d hi: %d\n" % (self.lo, self.hi))
            self.current = (self.lo + self.hi) // 2

        if self.lo == self.hi:
            # Report the revision that belongs to the printed index.
            message = (
                "Search complete. First bad version: %s"
                " at index: %d"
                % (
                    self.sorted_list[self.lo],
                    self.lo,
                )
            )
            self.logger.LogOutput(message)
            return True

        if self.lo > self.hi:
            # The upper bound itself tested good, so nothing in the range is
            # bad; this is not a "everything was skipped" situation.
            self.logger.LogOutput(
                "Search complete. No bad version found up to index: %d\n"
                % self.hi,
                print_to_console=verbose,
            )
            return True

        if any(
            index not in self.skipped_indices
            for index in range(self.lo, self.hi)
        ):
            return False
        self.logger.LogOutput(
            "All skipped indices between: %d and %d\n" % (self.lo, self.hi),
            print_to_console=verbose,
        )
        return True

    # Does a better job with chromeos flakiness.
    def GetNextFlakyBinary(self):
        """Move `current` to a non-skipped index found by bisecting outward.

        Starting from (lo, current, hi), every range whose midpoint is
        skipped is split into its lower and upper half, and the halves are
        examined breadth-first. When every candidate turns out to be skipped,
        `current` is left unchanged.
        """
        queue = [(self.lo, self.current, self.hi)]
        # Ranges already queued; without this, degenerate ranges such as
        # (k, k, k) would be re-queued forever once all candidates are
        # skipped.
        seen = set(queue)
        while queue:
            low, mid, high = queue.pop(0)
            if mid not in self.skipped_indices:
                self.current = mid
                return
            halves = (
                (low, (low + mid) // 2, mid),
                (mid, (mid + high) // 2, high),
            )
            for half in halves:
                if half not in seen:
                    seen.add(half)
                    queue.append(half)
        self.logger.LogOutput(
            "No untested candidate left between: %d and %d\n"
            % (self.lo, self.hi),
            print_to_console=verbose,
        )

    def GetNextFlakyLinear(self):
        """Move `current` to the nearest non-skipped index, scanning outward.

        At each distance the upward candidate (below `hi`) is preferred over
        the downward one (not below `lo`). `current` is left unchanged when
        both directions run out of range.
        """
        current_hi = self.current
        current_lo = self.current
        while True:
            if current_hi < self.hi and current_hi not in self.skipped_indices:
                self.current = current_hi
                break
            if current_lo >= self.lo and current_lo not in self.skipped_indices:
                self.current = current_lo
                break
            if current_lo < self.lo and current_hi >= self.hi:
                break

            current_hi += 1
            current_lo -= 1

    def GetNext(self):
        """Pick the next index to test and return its revision."""
        self.current = (self.hi + self.lo) // 2
        # Try going forward if current is skipped.
        if self.current in self.skipped_indices:
            self.GetNextFlakyBinary()

        # TODO: Add an estimated time remaining as well.
        span = self.hi - self.lo
        # math.log() raises ValueError for values <= 0, which happens when
        # the range has collapsed (e.g. a single-element list).
        min_tries = 1 + math.log(span, 2) if span > 0 else 1
        message = "Estimated tries: min: %d max: %d\n" % (
            min_tries,
            span - len(self.skipped_indices),
        )
        self.logger.LogOutput(message, print_to_console=verbose)
        message = "lo: %d hi: %d current: %d version: %s\n" % (
            self.lo,
            self.hi,
            self.current,
            self.sorted_list[self.current],
        )
        self.logger.LogOutput(message, print_to_console=verbose)
        self.logger.LogOutput(str(self), print_to_console=verbose)
        return self.sorted_list[self.current]

    def SetLoRevision(self, lo_revision):
        """Set the lower bound to the index of `lo_revision`."""
        self.lo = self.sorted_list.index(lo_revision)

    def SetHiRevision(self, hi_revision):
        """Set the upper bound to the index of `hi_revision`."""
        self.hi = self.sorted_list.index(hi_revision)

    def GetAllPoints(self):
        """Return one "<status> <index> <revision>" line per list entry."""
        return "".join(
            "%d %d %s\n" % (self.points[i].status, i, self.points[i].revision)
            for i in range(len(self.sorted_list))
        )

    def __str__(self):
        """Return a multi-line dump of the search state and all points."""
        revision_log = [self.sorted_list[index] for index in self.index_log]
        lines = [
            "Current: %d\n" % self.current,
            str(self.index_log) + "\n",
            str(revision_log) + "\n",
            str(self.status_log) + "\n",
            "Skipped indices:\n",
            str(self.skipped_indices) + "\n",
            self.GetAllPoints(),
        ]
        return "".join(lines)


class RevisionInfo:
    """Date, client and description of a revision plus its test status."""

    def __init__(self, date, client, description):
        self.date, self.client, self.description = date, client, description
        # -1 marks a revision for which no status has been recorded yet.
        self.status = -1


class VCSBinarySearcher:
    """Base class for binary searchers that work on a VCS.

    The hook methods below do nothing here and return None; only the
    good/bad revision setters carry logic.
    """

    def __init__(self):
        self.bs = BinarySearcher()
        # Filled by subclasses; this class only creates the empty slots.
        self.rim = {}
        self.checkout_dir = None
        self.current_revision = None
        self.current_ce = None

    def Initialize(self):
        """Hook without behavior in the base class."""

    def GetNextRevision(self):
        """Hook without behavior in the base class."""

    def CheckoutRevision(self, current_revision):
        """Hook without behavior in the base class."""

    def SetStatus(self, status):
        """Hook without behavior in the base class."""

    def Cleanup(self):
        """Hook without behavior in the base class."""

    def SetGoodRevision(self, revision):
        """Use `revision` as the lower bound; None leaves it untouched."""
        if revision is not None:
            assert revision in self.bs.sorted_list
            self.bs.SetLoRevision(revision)

    def SetBadRevision(self, revision):
        """Use `revision` as the upper bound; None leaves it untouched."""
        if revision is not None:
            assert revision in self.bs.sorted_list
            self.bs.SetHiRevision(revision)


class P4BinarySearcher(VCSBinarySearcher):
    """Binary searcher over the changes of a P4 (g4) client."""

    def __init__(self, p4_port, p4_paths, test_command):
        """Store the settings and create a temporary checkout directory.

        Args:
          p4_port: Value written as P4PORT into the checkout's .p4config.
          p4_paths: View entries handed to GetP4Command().
          test_command: Stored as `self.test_command`.
        """
        VCSBinarySearcher.__init__(self)
        self.p4_port = p4_port
        self.p4_paths = p4_paths
        self.test_command = test_command
        self.checkout_dir = tempfile.mkdtemp()
        self.ce = command_executer.GetCommandExecuter()
        # $HOSTNAME and $USER are left for the shell that runs the commands.
        self.client_name = "binary-searcher-$HOSTNAME-$USER"
        self.job_log_root = "/home/asharif/www/coreboot_triage/"
        self.changes = None

    def Initialize(self):
        """Create and sync the client, then load the list of changes."""
        self.Cleanup()
        # GetP4Command() takes (client_name, p4_port, p4_paths, checkoutdir,
        # p4_snapshot=""); the checkout directory must be the fourth
        # argument, not the snapshot.
        command = GetP4Command(
            self.client_name, self.p4_port, self.p4_paths, self.checkout_dir
        )
        self.ce.RunCommand(command)
        command = "cd %s && g4 changes ..." % self.checkout_dir
        _, out, _ = self.ce.RunCommandWOutput(command)
        self.changes = re.findall(r"Change (\d+)", out)
        change_infos = re.findall(
            r"Change (\d+) on ([\d/]+) by " r"([^\s]+) ('[^']*')", out
        )
        for change, date, client, description in change_infos:
            self.rim[change] = RevisionInfo(date, client, description)
        # g4 gives changes in reverse chronological order.
        self.changes.reverse()
        self.bs.SetSortedList(self.changes)

    def SetStatus(self, status):
        """Record `status` for the current revision and forward it."""
        self.rim[self.current_revision].status = status
        return self.bs.SetStatus(status)

    def GetNextRevision(self):
        """Ask the searcher for the next revision and remember it."""
        next_revision = self.bs.GetNext()
        self.current_revision = next_revision
        return next_revision

    def CleanupCLs(self):
        """Revert every change listed for this client."""
        if not os.path.isfile(self.checkout_dir + "/.p4config"):
            # Without a .p4config the g4 calls below have no port/client.
            command = "cd %s" % self.checkout_dir
            command += " && cp ${HOME}/.p4config ."
            command += ' && echo "P4PORT=' + self.p4_port + '" >> .p4config'
            command += (
                ' && echo "P4CLIENT=' + self.client_name + '" >> .p4config'
            )
            self.ce.RunCommand(command)
        command = "cd %s" % self.checkout_dir
        command += "; g4 changes -c %s" % self.client_name
        _, out, _ = self.ce.RunCommandWOutput(command)
        changes = re.findall(r"Change (\d+)", out)
        if changes:
            command = "cd %s" % self.checkout_dir
            for change in changes:
                command += "; g4 revert -c %s" % change
            self.ce.RunCommand(command)

    def CleanupClient(self):
        """Revert all files and delete the client."""
        command = "cd %s" % self.checkout_dir
        command += "; g4 revert ..."
        command += "; g4 client -d %s" % self.client_name
        self.ce.RunCommand(command)

    def Cleanup(self):
        """Revert pending changes, then remove the client."""
        self.CleanupCLs()
        self.CleanupClient()

    def __str__(self):
        """Return one tab-separated line per change.

        Changes whose status is still -1 get a short "<change>\t<status>"
        line; all others also list date, client, description and the paths
        of the .cmd/.out/.err job logs.
        """
        rows = []
        # `changes` stays None until Initialize() has run.
        for change in self.changes or []:
            ri = self.rim[change]
            if ri.status == -1:
                # Append (not assign): assigning here would throw away every
                # row collected so far.
                rows.append("%s\t%d\n" % (change, ri.status))
            else:
                rows.append(
                    "%s\t%d\t%s\t%s\t%s\t%s\t%s\t%s\n"
                    % (
                        change,
                        ri.status,
                        ri.date,
                        ri.client,
                        ri.description,
                        self.job_log_root + change + ".cmd",
                        self.job_log_root + change + ".out",
                        self.job_log_root + change + ".err",
                    )
                )
        return "".join(rows)


class P4GCCBinarySearcher(P4BinarySearcher):
    """P4 binary searcher that moves only the gcc part of the checkout."""

    # TODO: eventually get these patches from g4 instead of creating them manually
    def HandleBrokenCLs(self, current_revision):
        """Apply the stored patches whose revision range covers the revision.

        Each range is half-open: the first number is included, the second is
        not. For every matching range the files named on the "--- //..."
        lines of its patch are opened with g4 and the patch is applied. The
        resulting command (at least "pwd") is run with `self.current_ce`.
        """
        revision_number = int(current_revision)
        broken_ranges = (
            (44528, 44539),
            (44528, 44760),
            (44335, 44882),
        )
        steps = ["pwd"]
        for first, last in broken_ranges:
            if not first <= revision_number < last:
                continue
            patch_file = "/home/asharif/triage_tool/%d-%d.patch" % (
                first,
                last,
            )
            with open(patch_file, encoding="utf-8") as patch_handle:
                patch_text = patch_handle.read()
            steps.append("; cd %s" % self.checkout_dir)
            steps.extend(
                "; g4 open %s" % depot_file
                for depot_file in re.findall("--- (//.*)", patch_text)
            )
            steps.append("; patch -p2 < %s" % patch_file)
        self.current_ce.RunCommand("".join(steps))

    def CheckoutRevision(self, current_revision):
        """Sync the gcc directory to `current_revision` and patch it up.

        A command executer with a per-revision logger is created and kept in
        `self.current_ce` before any command is run.
        """
        self.current_ce = command_executer.GetCommandExecuter(
            logger.Logger(self.job_log_root, current_revision, True, subdir="")
        )

        self.CleanupCLs()
        # Change the revision of only the gcc part of the toolchain.
        sync_command = (
            "cd %s/gcctools/google_vendor_src_branch/gcc " % self.checkout_dir
        )
        sync_command += "&& g4 revert ...; g4 sync @%s" % current_revision
        self.current_ce.RunCommand(sync_command)

        self.HandleBrokenCLs(current_revision)


def Main(argv):
    """Bisect the built-in gcc/binutils P4 paths with a test script.

    The script is run as "<script> <checkout_dir>" for every revision picked
    by the searcher, and its exit status is fed back as the result.

    Args:
      argv: Command-line arguments without the program name.
    """
    # Common initializations
    ###  command_executer.InitCommandExecuter(True)
    # One executer is enough for all test-script runs.
    ce = command_executer.GetCommandExecuter()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n",
        "--num_tries",
        dest="num_tries",
        type=int,
        default=100,
        help="Number of tries.",
    )
    parser.add_argument(
        "-g",
        "--good_revision",
        dest="good_revision",
        help="Last known good revision.",
    )
    parser.add_argument(
        "-b",
        "--bad_revision",
        dest="bad_revision",
        help="Last known bad revision.",
    )
    # Required: without it there is nothing to run, and expanduser(None)
    # would fail with a TypeError instead of a usage message.
    parser.add_argument(
        "-s",
        "--script",
        dest="script",
        required=True,
        help="Script to run for every version.",
    )
    options = parser.parse_args(argv)
    # First get all revisions
    p4_paths = [
        "//depot2/gcctools/google_vendor_src_branch/gcc/gcc-4.4.3/...",
        "//depot2/gcctools/google_vendor_src_branch/binutils/"
        "binutils-2.20.1-mobile/...",
        "//depot2/gcctools/google_vendor_src_branch/"
        "binutils/binutils-20100303/...",
    ]
    p4gccbs = P4GCCBinarySearcher("perforce2:2666", p4_paths, "")

    # Main loop:
    terminated = False
    max_tries = options.num_tries
    tries_left = max_tries
    script = os.path.expanduser(options.script)

    try:
        p4gccbs.Initialize()
        p4gccbs.SetGoodRevision(options.good_revision)
        p4gccbs.SetBadRevision(options.bad_revision)
        while not terminated and tries_left > 0:
            current_revision = p4gccbs.GetNextRevision()
            # NOTE(review): CheckoutRevision() is never called here, so the
            # script presumably has to sync the checkout itself -- confirm.

            # Now run command to get the status
            command = "%s %s" % (script, p4gccbs.checkout_dir)
            status = ce.RunCommand(command)
            message = "Revision: %s produced: %d status\n" % (
                current_revision,
                status,
            )
            logger.GetLogger().LogOutput(message, print_to_console=verbose)
            terminated = p4gccbs.SetStatus(status)
            tries_left -= 1
            logger.GetLogger().LogOutput(str(p4gccbs), print_to_console=verbose)

        if not terminated:
            # Report the configured budget; the countdown is 0 by now.
            logger.GetLogger().LogOutput(
                "Tries: %d expired." % max_tries, print_to_console=verbose
            )
    except (KeyboardInterrupt, SystemExit):
        logger.GetLogger().LogOutput("Cleaning up...")
    finally:
        # The final search state is dumped exactly once, on every exit path.
        logger.GetLogger().LogOutput(str(p4gccbs.bs), print_to_console=verbose)
        p4gccbs.Cleanup()


# Script entry point: run the search with the command-line arguments
# (program name stripped).
if __name__ == "__main__":
    Main(sys.argv[1:])
