from __future__ import annotations

import argparse
import json
import os
import xml.etree.ElementTree as ET
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Generator

from tools.stats.upload_stats_lib import (
    download_s3_artifacts,
    is_rerun_disabled_tests,
    unzip,
    upload_workflow_stats_to_s3,
)
from tools.stats.upload_test_stats import process_xml_element


# XML tag of the report elements that each describe a single test case
TESTCASE_TAG = "testcase"
# Delimiter used to join (and later split) name, classname and filename into a test id
SEPARATOR = ";"


def process_report(
    report: Path,
) -> dict[str, dict[str, int]]:
    """
    Count how often each rerun disabled test passed or failed in one XML report.

    Args:
        report: path to an XML test report containing ``<testcase>`` elements.

    Returns:
        A dict keyed by the test id (name, classname and filename joined with
        SEPARATOR). Each value is a ``{"num_green": int, "num_red": int}`` counter.
        A test with num_green > 0 and num_red == 0 is a candidate for re-enabling,
        anything else is still flaky or failing.
    """
    root = ET.parse(report)

    # All rerun tests from a report are grouped here:
    #
    # * A successful test should be re-enabled if it's green after rerunning on all
    #   platforms where it is currently disabled
    # * Failures from pytest because pytest-flakefinder is used to run the same test
    #   multiple times, some of which could fail
    # * Skipped tests from unittest
    #
    # We want to keep track of how many times the test fails (num_red) or passes (num_green)
    all_tests: dict[str, dict[str, int]] = {}

    for test_case in root.iter(TESTCASE_TAG):
        parsed_test_case = process_xml_element(test_case)

        # Under --rerun-disabled-tests mode, a test is skipped when:
        # * it's skipped explicitly inside PyTorch code
        # * it's skipped because it's a normal enabled test
        # * or it's flaky (num_red > 0 and num_green > 0)
        # * or it's failing (num_red > 0 and num_green == 0)
        #
        # We care only about the latter two here
        skipped = parsed_test_case.get("skipped", None)

        # NB: Regular ONNX tests could return a list of subskips here where each item in the
        # list is a skipped message.  In the context of rerunning disabled tests, we could
        # ignore this case as returning a list of subskips only happens when tests are run
        # normally
        if skipped and (
            isinstance(skipped, list) or "num_red" not in skipped.get("message", "")
        ):
            continue

        name = parsed_test_case.get("name", "")
        classname = parsed_test_case.get("classname", "")
        filename = parsed_test_case.get("file", "")

        # A test id cannot be built without all three parts
        if not name or not classname or not filename:
            continue

        # Check if the test is a failure
        failure = parsed_test_case.get("failure", None)

        disabled_test_id = SEPARATOR.join([name, classname, filename])
        counters = all_tests.setdefault(
            disabled_test_id,
            {
                "num_green": 0,
                "num_red": 0,
            },
        )

        # Under --rerun-disabled-tests mode, if a test is not skipped or failed, it's
        # counted as a success. Otherwise, it's still flaky or failing
        if skipped:
            try:
                stats = json.loads(skipped.get("message", ""))
            except json.JSONDecodeError:
                stats = {}

            # The message may be valid JSON without being an object (for example a
            # quoted string that merely contains "num_red"). Treat that like an
            # unparsable message instead of crashing on stats.get below
            if not isinstance(stats, dict):
                stats = {}

            counters["num_green"] += stats.get("num_green", 0)
            counters["num_red"] += stats.get("num_red", 0)
        elif failure:
            # As a failure, increase the failure count
            counters["num_red"] += 1
        else:
            counters["num_green"] += 1

    return all_tests


def get_test_reports(
    repo: str, workflow_run_id: int, workflow_run_attempt: int
) -> Generator[Path, None, None]:
    """
    Download the "test-reports" artifacts of a workflow run from S3, unzip them into
    a temporary directory and yield every XML file found there.

    It is currently not possible to guess which test reports are from the
    rerun_disabled_tests workflow because the name doesn't include the test config.
    So, all reports need to be downloaded and examined.

    Args:
        repo: the GitHub repo of the workflow run (currently unused here).
        workflow_run_id: id of the workflow run to fetch artifacts for.
        workflow_run_attempt: attempt number of that workflow run.

    Yields:
        Paths relative to the temporary directory. The process working directory
        is switched to that directory while the generator is being consumed, and
        the directory is deleted afterwards, so each path must be used before
        advancing past the end of the generator.
    """
    # Remember where we started so the process is not left inside a directory
    # that TemporaryDirectory is about to delete
    original_cwd = os.getcwd()

    with TemporaryDirectory() as temp_dir:
        print("Using temporary directory:", temp_dir)
        # NOTE(review): presumably the download/unzip helpers write relative to the
        # current working directory, hence the chdir -- confirm against upload_stats_lib
        os.chdir(temp_dir)

        try:
            artifact_paths = download_s3_artifacts(
                "test-reports", workflow_run_id, workflow_run_attempt
            )
            for path in artifact_paths:
                unzip(path)

            yield from Path(".").glob("**/*.xml")
        finally:
            # Leave the temporary directory before it gets cleaned up, also when
            # the consumer stops iterating early or an error is raised
            os.chdir(original_cwd)


def get_disabled_test_name(test_id: str) -> tuple[str, str, str, str]:
    """
    Split a test id back into its parts and build the disabled test name.

    This follows the flaky bot convention; if that changes, this will also need
    to be updated.

    Args:
        test_id: name, classname and filename joined with SEPARATOR.

    Returns:
        A tuple of (disabled test name, name, classname, filename).

    Raises:
        ValueError: if test_id contains fewer than two separators.
    """
    # Split from the right with at most two cuts so that a test name which itself
    # contains the separator stays intact (assumes classname and filename do not
    # contain it)
    name, classname, filename = test_id.rsplit(SEPARATOR, 2)
    return f"{name} (__main__.{classname})", name, classname, filename


def prepare_record(
    workflow_id: int,
    workflow_run_attempt: int,
    name: str,
    classname: str,
    filename: str,
    flaky: bool,
    num_red: int = 0,
    num_green: int = 0,
) -> tuple[Any, dict[str, Any]]:
    """
    Build the record for one test together with the key that identifies it.

    Returns:
        A ``(key, record)`` pair. The key is the tuple (workflow_id,
        workflow_run_attempt, name, classname, filename); the record is a dict
        holding those same fields plus ``flaky``, ``num_green`` and ``num_red``.
    """
    # The identifying fields appear both in the key and at the start of the record
    record: dict[str, Any] = dict(
        workflow_id=workflow_id,
        workflow_run_attempt=workflow_run_attempt,
        name=name,
        classname=classname,
        filename=filename,
        flaky=flaky,
        num_green=num_green,
        num_red=num_red,
    )
    key = (workflow_id, workflow_run_attempt, name, classname, filename)

    return key, record


def save_results(
    workflow_id: int,
    workflow_run_attempt: int,
    all_tests: dict[str, dict[str, int]],
) -> None:
    """
    Classify the aggregated tests, log the outcome and upload one record per test.

    A test should be re-enabled when it has at least one green run and zero red
    runs; every other test is reported as still flaky. The records are uploaded
    under the "rerun_disabled_tests" collection.
    """
    # Partition the tests in a single pass, preserving the input order
    should_be_enabled_tests: dict[str, dict[str, int]] = {}
    still_flaky_tests: dict[str, dict[str, int]] = {}
    for test_id, stats in all_tests.items():
        only_green = bool(stats.get("num_green")) and stats.get("num_red") == 0
        if only_green:
            should_be_enabled_tests[test_id] = stats
        else:
            still_flaky_tests[test_id] = stats

    # One record per test, keyed so that duplicates collapse into a single entry
    records = {}
    for test_id, stats in all_tests.items():
        _, name, classname, filename = get_disabled_test_name(test_id)
        key, record = prepare_record(
            workflow_id=workflow_id,
            workflow_run_attempt=workflow_run_attempt,
            name=name,
            classname=classname,
            filename=filename,
            flaky=test_id in still_flaky_tests,
            num_green=stats.get("num_green", 0),
            num_red=stats.get("num_red", 0),
        )
        records[key] = record

    # Log the results
    print(f"The following {len(should_be_enabled_tests)} tests should be re-enabled:")
    for test_id in should_be_enabled_tests:
        disabled_test_name, _, _, filename = get_disabled_test_name(test_id)
        print(f"  {disabled_test_name} from {filename}")

    print(f"The following {len(still_flaky_tests)} are still flaky:")
    for test_id, stats in still_flaky_tests.items():
        disabled_test_name, _, _, filename = get_disabled_test_name(test_id)
        num_red = stats.get("num_red", 0)
        num_green = stats.get("num_green", 0)
        print(
            f"  {disabled_test_name} from {filename}, failing {num_red}/{num_red + num_green}"
        )

    upload_workflow_stats_to_s3(
        workflow_id,
        workflow_run_attempt,
        "rerun_disabled_tests",
        list(records.values()),
    )


def main(repo: str, workflow_run_id: int, workflow_run_attempt: int) -> None:
    """
    Find the list of all disabled tests that should be re-enabled and save the
    aggregated results.

    Args:
        repo: the GitHub repo the workflow run belongs to.
        workflow_run_id: id of the workflow run to examine.
        workflow_run_attempt: attempt number of that workflow run.
    """
    # Aggregated across all jobs
    all_tests: dict[str, dict[str, int]] = {}

    # Use the function's own parameters rather than the module-level ``args``,
    # which only exists when this file is executed as a script
    for report in get_test_reports(repo, workflow_run_id, workflow_run_attempt):
        tests = process_report(report)

        # The scheduled workflow has both rerun disabled tests and memory leak check jobs.
        # We are only interested in the former here
        if not is_rerun_disabled_tests(tests):
            continue

        for name, stats in tests.items():
            if name not in all_tests:
                # Copy so that later accumulation does not mutate the per-report dict
                all_tests[name] = stats.copy()
            else:
                all_tests[name]["num_green"] += stats.get("num_green", 0)
                all_tests[name]["num_red"] += stats.get("num_red", 0)

    save_results(
        workflow_run_id,
        workflow_run_attempt,
        all_tests,
    )


if __name__ == "__main__":
    # Command-line entry point: all three options are required and are passed
    # straight through to main()
    parser = argparse.ArgumentParser(
        # Describe what this script does; the previous text was inherited from an
        # artifact-upload script and did not match this tool
        description="Check which disabled tests should be re-enabled after rerunning "
        "them and upload the results to S3"
    )
    parser.add_argument(
        "--workflow-run-id",
        type=int,
        required=True,
        help="id of the workflow to get artifacts from",
    )
    parser.add_argument(
        "--workflow-run-attempt",
        type=int,
        required=True,
        help="which retry of the workflow this is",
    )
    parser.add_argument(
        "--repo",
        type=str,
        required=True,
        help="which GitHub repo this workflow run belongs to",
    )

    args = parser.parse_args()
    main(args.repo, args.workflow_run_id, args.workflow_run_attempt)
