#!/usr/bin/env python3
# Much of the logging code here was forked from https://github.com/ezyang/ghstack
# Copyright (c) Edward Z. Yang <ezyang@mit.edu>
"""Checks out the nightly development version of PyTorch and installs pre-built
binaries into the repo.

You can use this script to check out a new nightly branch with the following::

    $ ./tools/nightly.py checkout -b my-nightly-branch
    $ conda activate pytorch-deps

Or if you would like to re-use an existing conda environment, you can pass in
the regular environment parameters (--name or --prefix)::

    $ ./tools/nightly.py checkout -b my-nightly-branch -n my-env
    $ conda activate my-env

To install the nightly binaries built with CUDA, you can pass in the flag --cuda::

    $ ./tools/nightly.py checkout -b my-nightly-branch --cuda
    $ conda activate pytorch-deps

You can also use this tool to pull the nightly commits into the current branch as
well. This can be done with::

    $ ./tools/nightly.py pull -n my-env
    $ conda activate my-env

Pulling will reinstall the conda dependencies as well as the nightly binaries into
the repo directory.
"""

from __future__ import annotations

import argparse
import contextlib
import functools
import glob
import itertools
import json
import logging
import os
import re
import shutil
import subprocess
import sys
import tempfile
import time
import uuid
from ast import literal_eval
from datetime import datetime
from pathlib import Path
from platform import system as platform_system
from typing import Any, Callable, cast, Generator, Iterable, Iterator, Sequence, TypeVar


# Root of the PyTorch checkout (this script lives one level down, in tools/).
REPO_ROOT = Path(__file__).absolute().parent.parent
# Upstream remote used to fetch nightly commits that are missing locally.
GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git"
# Conda package specs that are always part of the solve.
SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx")

# Module-wide logger; assigned in main() once logging_manager() is active and
# read by the @timed decorator.
LOGGER: logging.Logger | None = None
# Download URL of a package, filled from one package record of the conda
# solve's JSON output (see conda_solve).
URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2"
# Timestamp format used as the prefix of each per-run log directory name.
DATETIME_FORMAT = "%Y-%m-%d_%Hh%Mm%Ss"
# A full 40-hex-digit git SHA-1.
SHA1_RE = re.compile(r"(?P<sha1>[0-9a-fA-F]{40})")
# Matches the "://user:password@" credentials portion of a URL.
USERNAME_PASSWORD_RE = re.compile(r":\/\/(.*?)\@")
# Matches per-run log directory names ("<DATETIME_FORMAT>_<uuid>") as created by
# logging_run_dir(); logging_rotate() only deletes entries matching this.
LOG_DIRNAME_RE = re.compile(
    r"(?P<datetime>\d{4}-\d\d-\d\d_\d\dh\d\dm\d\ds)_"
    r"(?P<uuid>[0-9a-f]{8}-(?:[0-9a-f]{4}-){3}[0-9a-f]{12})",
)


class Formatter(logging.Formatter):
    redactions: dict[str, str]

    def __init__(self, fmt: str | None = None, datefmt: str | None = None) -> None:
        super().__init__(fmt, datefmt)
        self.redactions = {}

    # Remove sensitive information from URLs
    def _filter(self, s: str) -> str:
        s = USERNAME_PASSWORD_RE.sub(r"://<USERNAME>:<PASSWORD>@", s)
        for needle, replace in self.redactions.items():
            s = s.replace(needle, replace)
        return s

    def formatMessage(self, record: logging.LogRecord) -> str:
        if record.levelno == logging.INFO or record.levelno == logging.DEBUG:
            # Log INFO/DEBUG without any adornment
            return record.getMessage()
        else:
            # I'm not sure why, but formatMessage doesn't show up
            # even though it's in the typeshed for Python >3
            return super().formatMessage(record)

    def format(self, record: logging.LogRecord) -> str:
        return self._filter(super().format(record))

    def redact(self, needle: str, replace: str = "<REDACTED>") -> None:
        """Redact specific strings; e.g., authorization tokens.  This won't
        retroactively redact stuff you've already leaked, so make sure
        you redact things as soon as possible.
        """
        # Don't redact empty strings; this will lead to something
        # that looks like s<REDACTED>t<REDACTED>r<REDACTED>...
        if needle == "":
            return
        self.redactions[needle] = replace


def git(*args: str) -> list[str]:
    """Return an argv list that runs ``git`` with *args* inside the repository root."""
    command = ["git", "-C", str(REPO_ROOT)]
    command.extend(args)
    return command


@functools.lru_cache
def logging_base_dir() -> Path:
    """Return ``<repo>/nightly/log``, creating it on first call (result is cached)."""
    log_root = REPO_ROOT.joinpath("nightly", "log")
    log_root.mkdir(parents=True, exist_ok=True)
    return log_root


@functools.lru_cache
def logging_run_dir() -> Path:
    """Return this process's log directory, creating it on first call.

    The directory name is a DATETIME_FORMAT timestamp plus a uuid1. Because
    the result is cached, every caller in one process gets the same directory.
    """
    parent = logging_base_dir()
    stamp = datetime.now().strftime(DATETIME_FORMAT)
    run_dir = parent / f"{stamp}_{uuid.uuid1()}"
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir


@functools.lru_cache
def logging_record_argv() -> None:
    """Write the command line (joined with ``subprocess.list2cmdline``) to ``argv``.

    The file goes in the run's log directory; caching makes this a one-shot
    operation per process.
    """
    command_line = subprocess.list2cmdline(sys.argv)
    argv_file = logging_run_dir() / "argv"
    argv_file.write_text(command_line, encoding="utf-8")


def logging_record_exception(e: BaseException) -> None:
    """Save the class name of *e* to ``exception`` in the run's log directory."""
    exc_name = type(e).__name__
    target = logging_run_dir() / "exception"
    target.write_text(exc_name, encoding="utf-8")


def logging_rotate() -> None:
    """Delete per-run log directories beyond the first 1000 in descending order.

    Entries of the log base directory are sorted in reverse; since run
    directory names start with a sortable timestamp, this keeps the newest
    ones. Only entries whose names match LOG_DIRNAME_RE are ever deleted.
    """
    keep = 1000
    newest_first = sorted(logging_base_dir().iterdir(), reverse=True)
    for candidate in newest_first[keep:]:
        # Sanity check: never delete something that isn't a run directory.
        if LOG_DIRNAME_RE.fullmatch(candidate.name):
            shutil.rmtree(candidate)


@contextlib.contextmanager
def logging_manager(*, debug: bool = False) -> Generator[logging.Logger, None, None]:
    """Set up logging for one run and yield the configured logger.

    If a failure happens while this setup runs we cannot save the user in a
    reasonable way, so setup errors propagate unhandled.

    Logging structure: a single logger named "conda-pytorch" processes all
    events. It has two handlers: a console (stderr) handler at INFO level
    (DEBUG when *debug* is true) and a file handler writing every event to
    ``nightly.log`` in the per-run log directory.

    Any exception escaping the ``with`` body is logged with its traceback to
    both handlers, its type is recorded in the run directory, the log file
    location is printed, and the process exits with status 1. The handlers are
    detached and closed when the context exits.
    """
    formatter = Formatter(fmt="%(levelname)s: %(message)s", datefmt="")
    root_logger = logging.getLogger("conda-pytorch")
    root_logger.setLevel(logging.DEBUG)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG if debug else logging.INFO)
    console_handler.setFormatter(formatter)
    root_logger.addHandler(console_handler)

    log_file = logging_run_dir() / "nightly.log"

    file_handler = logging.FileHandler(log_file)
    file_handler.setFormatter(formatter)
    root_logger.addHandler(file_handler)
    logging_record_argv()

    try:
        logging_rotate()
        print(f"log file: {log_file}")
        yield root_logger
    except Exception as e:
        # Log through our own logger: the module-level logging.* helpers write
        # to the *root* logger, which has none of the handlers set up above,
        # so the traceback would never reach nightly.log.
        root_logger.exception("Fatal exception")
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
    except BaseException as e:
        # e.g. KeyboardInterrupt. You could log at DEBUG here to keep the
        # backtrace off the console, but there is no reason to hide it from
        # technically savvy users.
        root_logger.info("", exc_info=True)
        logging_record_exception(e)
        print(f"log file: {log_file}")
        sys.exit(1)
    finally:
        # Detach and close the handlers so re-entering this manager in the same
        # process neither duplicates output nor leaks the open log file.
        for handler in (console_handler, file_handler):
            root_logger.removeHandler(handler)
            handler.close()


def check_branch(subcommand: str, branch: str | None) -> str | None:
    """Checks that the branch name can be checked out."""
    if subcommand != "checkout":
        return None
    # first make sure actual branch name was given
    if branch is None:
        return "Branch name to checkout must be supplied with '-b' option"
    # next check that the local repo is clean
    cmd = git("status", "--untracked-files=no", "--porcelain")
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    if stdout.strip():
        return "Need to have clean working tree to checkout!\n\n" + stdout
    # next check that the branch name doesn't already exist
    cmd = git("show-ref", "--verify", "--quiet", f"refs/heads/{branch}")
    p = subprocess.run(cmd, capture_output=True, check=False)  # type: ignore[assignment]
    if not p.returncode:
        return f"Branch {branch!r} already exists"
    return None


@contextlib.contextmanager
def timer(logger: logging.Logger, prefix: str) -> Iterator[None]:
    """Log at INFO how long the ``with`` body took, as ``"<prefix> took N [s]"``.

    Nothing is logged if the body raises.
    """
    begin = time.perf_counter()
    yield
    elapsed = time.perf_counter() - begin
    logger.info("%s took %.3f [s]", prefix, elapsed)


F = TypeVar("F", bound=Callable[..., Any])


def timed(prefix: str) -> Callable[[F], F]:
    """Decorator factory that logs *prefix* and then times each call with ``timer``.

    The module-level LOGGER is looked up at call time, so it must be set
    before a decorated function is invoked.
    """

    def decorate(func: F) -> F:
        @functools.wraps(func)
        def timed_call(*args: Any, **kwargs: Any) -> Any:
            active_logger = cast(logging.Logger, LOGGER)
            active_logger.info(prefix)
            with timer(active_logger, prefix):
                return func(*args, **kwargs)

        return cast(F, timed_call)

    return decorate


def _make_channel_args(
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> list[str]:
    args = []
    for channel in channels:
        args.extend(["--channel", channel])
    if override_channels:
        args.append("--override-channels")
    return args


@timed("Solving conda environment")
def conda_solve(
    specs: Iterable[str],
    *,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> tuple[list[str], str, str, bool, list[str]]:
    """Dry-run a conda solve and split the PyTorch package from its deps.

    The target environment is *prefix* if given, else *name*. When neither is
    given, a new environment named ``pytorch-deps`` is planned, and the solve
    runs as a ``conda create`` against the placeholder name ``__pytorch__``.

    Returns:
        ``(deps, pytorch, platform, existing_env, env_opts)``:
        - *deps*: download URLs of every solved package except ``pytorch``.
        - *pytorch*: the PyTorch package URL.
        - *platform*: the package's conda platform string.
        - *existing_env*: whether an existing environment was requested.
        - *env_opts*: the conda CLI options selecting the target environment.

    Raises:
        RuntimeError: if the solution contains no ``pytorch`` package, or no
            platform for it.
    """
    # compute what environment to use
    if prefix is not None:
        existing_env, env_opts = True, ["--prefix", prefix]
    elif name is not None:
        existing_env, env_opts = True, ["--name", name]
    else:
        # create new environment
        existing_env, env_opts = False, ["--name", "pytorch-deps"]

    # run the solve; --dry-run means no environment is touched yet
    if existing_env:
        cmd = ["conda", "install", "--yes", "--dry-run", "--json", *env_opts]
    else:
        cmd = [
            "conda",
            "create",
            "--yes",
            "--dry-run",
            "--json",
            "--name",
            "__pytorch__",
        ]
    cmd.extend(_make_channel_args(channels=channels, override_channels=override_channels))
    cmd.extend(specs)
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")

    # parse solution: every package conda would link, as a download URL
    solve = json.loads(stdout)
    deps: list[str] = []
    pytorch, platform = "", ""
    for pkg in solve["actions"]["LINK"]:
        url = URL_FORMAT.format(**pkg)
        if pkg["name"] == "pytorch":
            pytorch = url
            platform = pkg["platform"]
        else:
            deps.append(url)
    # Explicit checks rather than `assert`, which `python -O` strips out.
    if not pytorch:
        raise RuntimeError("PyTorch package not found in solve")
    if not platform:
        raise RuntimeError("Platform not found in solve")
    return deps, pytorch, platform, existing_env, env_opts


@timed("Installing dependencies")
def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> None:
    """Install the *deps* package URLs into the environment selected by *env_opts*.

    For a new environment (``existing_env`` false), any previous environment
    selected by *env_opts* is removed first and the environment is created
    from scratch. Otherwise the packages are installed in place. ``--no-deps``
    is passed; presumably *deps* is already a complete solve (see conda_solve).
    """
    if existing_env:
        subcommand = "install"
    else:
        # Start over: drop the previously created environment.
        subprocess.check_call(["conda", "env", "remove", "--yes", *env_opts])
        subcommand = "create"
    subprocess.check_call(["conda", subcommand, "--yes", "--no-deps", *env_opts, *deps])


@timed("Installing pytorch nightly binaries")
def pytorch_install(url: str) -> tempfile.TemporaryDirectory[str]:
    """Create a throwaway conda prefix holding only the package at *url*.

    The returned TemporaryDirectory owns that prefix. Use it as a context
    manager so the files are deleted afterwards.
    """
    scratch = tempfile.TemporaryDirectory(prefix="conda-pytorch-")
    subprocess.check_call(
        [
            "conda",
            "create",
            "--yes",
            "--no-deps",
            f"--prefix={scratch.name}",
            url,
        ]
    )
    return scratch


def _site_packages(dirname: str, platform: str) -> Path:
    if platform.startswith("win"):
        template = os.path.join(dirname, "Lib", "site-packages")
    else:
        template = os.path.join(dirname, "lib", "python*.*", "site-packages")
    return Path(next(glob.iglob(template))).absolute()


def _ensure_commit(git_sha1: str) -> None:
    """Fetch commit *git_sha1* from the GitHub remote unless it is already local."""
    # `git cat-file -e <rev>^{commit}` exits 0 when that commit is available locally.
    probe = subprocess.run(
        git("cat-file", "-e", git_sha1 + r"^{commit}"),
        capture_output=True,
        check=False,
    )
    if probe.returncode != 0:
        subprocess.check_call(git("fetch", GITHUB_REMOTE_URL, git_sha1))


def _nightly_version(site_dir: Path) -> str:
    # first get the git version from the installed module
    version_file = site_dir / "torch" / "version.py"
    with version_file.open(encoding="utf-8") as f:
        for line in f:
            if not line.startswith("git_version"):
                continue
            git_version = literal_eval(line.partition("=")[2].strip())
            break
        else:
            raise RuntimeError(f"Could not find git_version in {version_file}")

    print(f"Found released git version {git_version}")
    # now cross reference with nightly version
    _ensure_commit(git_version)
    cmd = git("show", "--no-patch", "--format=%s", git_version)
    stdout = subprocess.check_output(cmd, text=True, encoding="utf-8")
    m = SHA1_RE.search(stdout)
    if m is None:
        raise RuntimeError(
            f"Could not find nightly release in git history:\n  {stdout}"
        )
    nightly_version = m.group("sha1")
    print(f"Found nightly release version {nightly_version}")
    # now checkout nightly version
    _ensure_commit(nightly_version)
    return nightly_version


@timed("Checking out nightly PyTorch")
def checkout_nightly_version(branch: str, site_dir: Path) -> None:
    """Create and switch to *branch*, starting at the nightly commit for *site_dir*."""
    start_point = _nightly_version(site_dir)
    subprocess.check_call(git("checkout", "-b", branch, start_point))


@timed("Pulling nightly PyTorch")
def pull_nightly_version(site_dir: Path) -> None:
    """Merge the nightly commit for *site_dir* into the current branch."""
    target = _nightly_version(site_dir)
    subprocess.check_call(git("merge", target))


def _get_listing_linux(source_dir: Path) -> list[Path]:
    return list(
        itertools.chain(
            source_dir.glob("*.so"),
            (source_dir / "lib").glob("*.so"),
            (source_dir / "lib").glob("*.so.*"),
        )
    )


def _get_listing_osx(source_dir: Path) -> list[Path]:
    # oddly, these are .so files even on Mac
    return list(
        itertools.chain(
            source_dir.glob("*.so"),
            (source_dir / "lib").glob("*.dylib"),
        )
    )


def _get_listing_win(source_dir: Path) -> list[Path]:
    return list(
        itertools.chain(
            source_dir.glob("*.pyd"),
            (source_dir / "lib").glob("*.lib"),
            (source_dir / "lib").glob(".dll"),
        )
    )


def _glob_pyis(d: Path) -> set[str]:
    return {p.relative_to(d).as_posix() for p in d.rglob("*.pyi")}


def _find_missing_pyi(source_dir: Path, target_dir: Path) -> list[Path]:
    """Return the ``.pyi`` stubs under *source_dir* that *target_dir* lacks.

    Stubs are matched by relative path. The result is sorted and each entry
    is joined onto *source_dir*.
    """
    candidates = _glob_pyis(source_dir)
    already_present = _glob_pyis(target_dir)
    return sorted(source_dir / rel for rel in candidates if rel not in already_present)


def _get_listing(source_dir: Path, target_dir: Path, platform: str) -> list[Path]:
    if platform.startswith("linux"):
        listing = _get_listing_linux(source_dir)
    elif platform.startswith("osx"):
        listing = _get_listing_osx(source_dir)
    elif platform.startswith("win"):
        listing = _get_listing_win(source_dir)
    else:
        raise RuntimeError(f"Platform {platform!r} not recognized")
    listing.extend(_find_missing_pyi(source_dir, target_dir))
    listing.append(source_dir / "version.py")
    listing.append(source_dir / "testing" / "_internal" / "generated")
    listing.append(source_dir / "bin")
    listing.append(source_dir / "include")
    return listing


def _remove_existing(path: Path) -> None:
    if path.exists():
        if path.is_dir():
            shutil.rmtree(path)
        else:
            path.unlink()


def _move_single(
    src: Path,
    source_dir: Path,
    target_dir: Path,
    mover: Callable[[Path, Path], None],
    verb: str,
) -> None:
    """Place *src* at the same relative location under *target_dir*.

    Anything already at the destination is removed first. A directory is
    recreated and each file inside it is transferred with *mover* (e.g.
    ``shutil.copy2`` or ``os.link``); a single file is transferred directly.
    Each transfer is printed as ``"<verb> <src> -> <dst>"``.
    """
    relpath = src.relative_to(source_dir)
    trg = target_dir / relpath
    _remove_existing(trg)
    # move over new files
    if src.is_dir():
        trg.mkdir(parents=True, exist_ok=True)
        # os.walk is top-down, so each subdirectory is created below before
        # the walk descends into it and moves its files.
        for root, dirs, files in os.walk(src):
            relroot = Path(root).relative_to(src)
            for name in files:
                relname = relroot / name
                s = src / relname
                t = trg / relname
                print(f"{verb} {s} -> {t}")
                mover(s, t)
            for name in dirs:
                (trg / relroot / name).mkdir(parents=True, exist_ok=True)
    else:
        # The destination's directory (e.g. a package holding a new .pyi
        # stub) may not exist in the checkout; mover would then fail.
        trg.parent.mkdir(parents=True, exist_ok=True)
        print(f"{verb} {src} -> {trg}")
        mover(src, trg)


def _copy_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
    """Copy every entry of *listing* from *source_dir* into *target_dir*.

    Files are copied with ``shutil.copy2`` (contents plus metadata).
    """
    for entry in listing:
        _move_single(entry, source_dir, target_dir, mover=shutil.copy2, verb="Copying")


def _link_files(listing: list[Path], source_dir: Path, target_dir: Path) -> None:
    """Hard-link (``os.link``) every entry of *listing* from *source_dir* into *target_dir*."""
    for entry in listing:
        _move_single(entry, source_dir, target_dir, mover=os.link, verb="Linking")


@timed("Moving nightly files into repo")
def move_nightly_files(site_dir: Path, platform: str) -> None:
    """Move PyTorch files from the temporary install into the repo's ``torch/``.

    On Windows the files are copied. Elsewhere they are hard-linked, falling
    back to copying when linking fails with an OS error, e.g. because the
    temporary install is on a different filesystem than the repo.
    """
    # get file listing
    source_dir = site_dir / "torch"
    target_dir = REPO_ROOT / "torch"
    listing = _get_listing(source_dir, target_dir, platform)
    # copy / link files
    if platform.startswith("win"):
        _copy_files(listing, source_dir, target_dir)
        return
    try:
        _link_files(listing, source_dir, target_dir)
    except OSError as e:
        # Hard links cannot cross filesystems (EXDEV) and may be disallowed;
        # copying works regardless. Only OSError is caught so that genuine
        # bugs surface instead of being masked by a retry.
        print(f"Linking failed ({e}); copying files instead")
        _copy_files(listing, source_dir, target_dir)


def _available_envs() -> dict[str, str]:
    """Map conda environment names to their paths, parsed from ``conda env list``.

    Blank lines, ``#`` comment lines, and single-column (unnamed, path-only)
    entries are skipped. The first column is taken as the name and the last
    as the path, so any marker column in between is ignored (conda presumably
    uses ``*`` there for the active environment).
    """
    listing = subprocess.check_output(
        ["conda", "env", "list"], text=True, encoding="utf-8"
    )
    envs: dict[str, str] = {}
    for raw in listing.splitlines():
        stripped = raw.strip()
        if not stripped or stripped.startswith("#"):
            continue
        columns = stripped.split()
        if len(columns) > 1:
            envs[columns[0]] = columns[-1]
    return envs


@timed("Writing pytorch-nightly.pth")
def write_pth(env_opts: list[str], platform: str) -> None:
    """Write ``pytorch-nightly.pth`` into the target environment's site-packages.

    The file lists the repo root, so the development checkout becomes
    importable. *env_opts* is a two-item option pair such as
    ``["--name", NAME]`` or ``["--prefix", PATH]``; a name is resolved to its
    path through ``conda env list``.
    """
    option, location = env_opts
    env_dir = _available_envs()[location] if option == "--name" else location
    pth_file = _site_packages(env_dir, platform) / "pytorch-nightly.pth"
    contents = (
        "# This file was autogenerated by PyTorch's tools/nightly.py\n"
        "# Please delete this file if you no longer need the following development\n"
        "# version of PyTorch to be importable\n"
        f"{REPO_ROOT}\n"
    )
    pth_file.write_text(contents, encoding="utf-8")


def install(
    specs: Iterable[str],
    *,
    logger: logging.Logger,
    subcommand: str = "checkout",
    branch: str | None = None,
    name: str | None = None,
    prefix: str | None = None,
    channels: Iterable[str] = ("pytorch-nightly",),
    override_channels: bool = False,
) -> None:
    """Development install of PyTorch"""
    specs = list(specs)
    deps, pytorch, platform, existing_env, env_opts = conda_solve(
        specs=specs,
        name=name,
        prefix=prefix,
        channels=channels,
        override_channels=override_channels,
    )
    if deps:
        deps_install(deps, existing_env, env_opts)

    with pytorch_install(pytorch) as pytorch_dir:
        site_dir = _site_packages(pytorch_dir, platform)
        if subcommand == "checkout":
            checkout_nightly_version(cast(str, branch), site_dir)
        elif subcommand == "pull":
            pull_nightly_version(site_dir)
        else:
            raise ValueError(f"Subcommand {subcommand} must be one of: checkout, pull.")
        move_nightly_files(site_dir, platform)

    write_pth(env_opts, platform)
    logger.info(
        "-------\nPyTorch Development Environment set up!\nPlease activate to "
        "enable this environment:\n  $ conda activate %s",
        env_opts[1],
    )


def make_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the ``checkout`` and ``pull`` subcommands.

    A subcommand is required. Both subcommands accept:
    - ``-n/--name`` and ``-p/--prefix`` to select the environment;
    - ``-v/--verbose``;
    - ``--override-channels``;
    - ``-c/--channel``, which may be repeated.

    ``checkout`` also takes ``-b/--branch``. On Linux and Windows,
    ``--cuda [VERSION]`` is offered as well. When that flag is absent the
    ``cuda`` attribute is left off the namespace entirely
    (``argparse.SUPPRESS``).
    """
    p = argparse.ArgumentParser()
    # subcommands; required so that running with none gives a usage error
    # instead of failing later on options only the subparsers define.
    subcmd = p.add_subparsers(
        dest="subcmd", help="subcommand to execute", required=True
    )
    checkout = subcmd.add_parser("checkout", help="checkout a new branch")
    checkout.add_argument(
        "-b",
        "--branch",
        help="Branch name to checkout",
        dest="branch",
        default=None,
        metavar="NAME",
    )
    pull = subcmd.add_parser(
        "pull", help="pulls the nightly commits into the current branch"
    )
    # general arguments
    subparsers = [checkout, pull]
    for subparser in subparsers:
        subparser.add_argument(
            "-n",
            "--name",
            help="Name of environment",
            dest="name",
            default=None,
            metavar="ENVIRONMENT",
        )
        subparser.add_argument(
            "-p",
            "--prefix",
            help="Full path to environment location (i.e. prefix)",
            dest="prefix",
            default=None,
            metavar="PATH",
        )
        subparser.add_argument(
            "-v",
            "--verbose",
            help="Provide debugging info",
            dest="verbose",
            default=False,
            action="store_true",
        )
        subparser.add_argument(
            "--override-channels",
            help="Do not search default or .condarc channels.",
            dest="override_channels",
            default=False,
            action="store_true",
        )
        subparser.add_argument(
            "-c",
            "--channel",
            help=(
                "Additional channel to search for packages. "
                "'pytorch-nightly' will always be prepended to this list."
            ),
            dest="channels",
            action="append",
            metavar="CHANNEL",
        )
        if platform_system() in {"Linux", "Windows"}:
            subparser.add_argument(
                "--cuda",
                help=(
                    "CUDA version to install "
                    "(defaults to the latest version available on the platform)"
                ),
                dest="cuda",
                nargs="?",
                default=argparse.SUPPRESS,
                metavar="VERSION",
            )
    return p


def main(args: Sequence[str] | None = None) -> None:
    """Command-line entry point: parse *args* (``sys.argv[1:]`` when None) and install."""
    global LOGGER
    parser = make_parser()
    ns = parser.parse_args(args)
    # Only the "checkout" subparser defines --branch.
    ns.branch = getattr(ns, "branch", None)
    error = check_branch(ns.subcmd, ns.branch)
    if error:
        sys.exit(error)

    specs = [*SPECS_TO_INSTALL]
    channels = ["pytorch-nightly"]
    # `cuda` only exists on the namespace when --cuda was given (SUPPRESS default).
    if not hasattr(ns, "cuda"):
        specs.append("pytorch-mutex=*=*cpu*")
    else:
        specs.append("pytorch-cuda" if ns.cuda is None else f"pytorch-cuda={ns.cuda}")
        specs.append("pytorch-mutex=*=*cuda*")
        channels.append("nvidia")
    if ns.channels:
        channels += ns.channels

    with logging_manager(debug=ns.verbose) as logger:
        LOGGER = logger
        install(
            specs=specs,
            subcommand=ns.subcmd,
            branch=ns.branch,
            name=ns.name,
            prefix=ns.prefix,
            logger=logger,
            channels=channels,
            override_channels=ns.override_channels,
        )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
