import argparse
import hashlib
import json
import logging
import os
import platform
import stat
import subprocess
import sys
import urllib.error
import urllib.request
from pathlib import Path


# String representing the host platform (e.g. Linux, Darwin).
HOST_PLATFORM = platform.system()
HOST_PLATFORM_ARCH = platform.system() + "-" + platform.processor()

# PyTorch directory root
try:
    result = subprocess.run(
        ["git", "rev-parse", "--show-toplevel"],
        stdout=subprocess.PIPE,
        check=True,
    )
    PYTORCH_ROOT = result.stdout.decode("utf-8").strip()
except subprocess.CalledProcessError:
    # If git is not installed, compute repo root as 3 folders up from this file
    path_ = os.path.abspath(__file__)
    for _ in range(4):
        path_ = os.path.dirname(path_)
    PYTORCH_ROOT = path_

DRY_RUN = False


def compute_file_sha256(path: str) -> str:
    """Compute the SHA256 hash of a file and return it as a hex string."""
    # If the file doesn't exist, return an empty string.
    if not os.path.exists(path):
        return ""

    hash = hashlib.sha256()

    # Open the file in binary mode and hash it.
    with open(path, "rb") as f:
        for b in f:
            hash.update(b)

    # Return the hash as a hexadecimal string.
    return hash.hexdigest()


def report_download_progress(
    chunk_number: int, chunk_size: int, file_size: int
) -> None:
    """
    Pretty printer for file download progress.
    """
    if file_size != -1:
        percent = min(1, (chunk_number * chunk_size) / file_size)
        bar = "#" * int(64 * percent)
        sys.stdout.write(f"\r0% |{bar:<64}| {int(percent * 100)}%")


def check(binary_path: Path, reference_hash: str) -> bool:
    """Check whether the binary exists and is the right one.

    If there is hash difference, delete the actual binary.
    """
    if not binary_path.exists():
        logging.info("%s does not exist.", binary_path)
        return False

    existing_binary_hash = compute_file_sha256(str(binary_path))
    if existing_binary_hash == reference_hash:
        return True

    logging.warning(
        """\
Found binary hash does not match reference!

Found hash: %s
Reference hash: %s

Deleting %s just to be safe.
""",
        existing_binary_hash,
        reference_hash,
        binary_path,
    )
    if DRY_RUN:
        logging.critical(
            "In dry run mode, so not actually deleting the binary. But consider deleting it ASAP!"
        )
        return False

    try:
        binary_path.unlink()
    except OSError as e:
        logging.critical("Failed to delete binary: %s", e)
        logging.critical(
            "Delete this binary as soon as possible and do not execute it!"
        )

    return False


def download(
    name: str,
    output_dir: str,
    url: str,
    reference_bin_hash: str,
) -> bool:
    """
    Download a platform-appropriate binary if one doesn't already exist at the expected location and verifies
    that it is the right binary by checking its SHA256 hash against the expected hash.
    """
    # First check if we need to do anything
    binary_path = Path(output_dir, name)
    if check(binary_path, reference_bin_hash):
        logging.info("Correct binary already exists at %s. Exiting.", binary_path)
        return True

    # Create the output folder
    binary_path.parent.mkdir(parents=True, exist_ok=True)

    # Download the binary
    logging.info("Downloading %s to %s", url, binary_path)

    if DRY_RUN:
        logging.info("Exiting as there is nothing left to do in dry run mode")
        return True

    urllib.request.urlretrieve(
        url,
        binary_path,
        reporthook=report_download_progress if sys.stdout.isatty() else None,
    )

    logging.info("Downloaded %s successfully.", name)

    # Check the downloaded binary
    if not check(binary_path, reference_bin_hash):
        logging.critical("Downloaded binary %s failed its hash check", name)
        return False

    # Ensure that executable bits are set
    mode = os.stat(binary_path).st_mode
    mode |= stat.S_IXUSR
    os.chmod(binary_path, mode)

    logging.info("Using %s located at %s", name, binary_path)
    return True


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="downloads and checks binaries from s3",
    )
    parser.add_argument(
        "--config-json",
        required=True,
        help="Path to config json that describes where to find binaries and hashes",
    )
    parser.add_argument(
        "--linter",
        required=True,
        help="Which linter to initialize from the config json",
    )
    parser.add_argument(
        "--output-dir",
        required=True,
        help="place to put the binary",
    )
    parser.add_argument(
        "--output-name",
        required=True,
        help="name of binary",
    )
    parser.add_argument(
        "--dry-run",
        default=False,
        help="do not download, just print what would be done",
    )

    args = parser.parse_args()
    if args.dry_run == "0":
        DRY_RUN = False
    else:
        DRY_RUN = True

    logging.basicConfig(
        format="[DRY_RUN] %(levelname)s: %(message)s"
        if DRY_RUN
        else "%(levelname)s: %(message)s",
        level=logging.INFO,
        stream=sys.stderr,
    )

    config = json.load(open(args.config_json))
    config = config[args.linter]

    # Allow processor specific binaries for platform (namely Intel and M1 binaries for MacOS)
    host_platform = HOST_PLATFORM if HOST_PLATFORM in config else HOST_PLATFORM_ARCH
    # If the host platform is not in platform_to_hash, it is unsupported.
    if host_platform not in config:
        logging.error("Unsupported platform: %s/%s", HOST_PLATFORM, HOST_PLATFORM_ARCH)
        sys.exit(1)

    url = config[host_platform]["download_url"]
    hash = config[host_platform]["hash"]

    ok = download(args.output_name, args.output_dir, url, hash)
    if not ok:
        logging.critical("Unable to initialize %s", args.linter)
        sys.exit(1)
