import argparse
import logging
import os
from functools import partial

import torch
import torch._dynamo as dynamo
import torch.utils._pytree as pytree
from torch._dynamo.testing import reduce_to_scalar_loss
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.profiler import profile, ProfilerActivity, record_function


try:
    from .common import timed
    from .dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup
except ImportError:
    from common import timed
    from dist_util import apply_fsdp, cleanup, get_model, model_iter_fn, setup

# Logger for this benchmark script, named after the module.
log = logging.getLogger(name=__name__)


def torchviz_model(args, model, inputs, rank):
    """Build a torchviz graph of the autograd graph for one forward pass.

    Every rank runs ``model(*inputs)``, reduces the output to a scalar with
    ``reduce_to_scalar_loss`` and builds the graph with ``make_dot``
    (parameter names, attributes and saved tensors shown). Only rank 0 calls
    ``render("torchviz.dot")`` on the result.

    Args:
        args: unused; accepted so all helpers share one signature.
        model: module called as ``model(*inputs)``; its ``named_parameters``
            label the graph nodes.
        inputs: positional arguments for ``model``.
        rank: this process's rank; only rank 0 renders.
    """
    from torchviz import make_dot

    scalar_loss = reduce_to_scalar_loss(model(*inputs))
    graph = make_dot(
        scalar_loss,
        params=dict(model.named_parameters()),
        show_attrs=True,
        show_saved=True,
    )
    if rank != 0:
        return
    graph.render("torchviz.dot")


def profile_model(args, model, inputs, rank):
    """Profile ``args.repeat`` forward/backward iterations of ``model``.

    In each iteration, the forward pass is recorded under the "Forward"
    label. That pass is ``model(*inputs)`` followed by
    ``reduce_to_scalar_loss``. The ``loss.backward()`` call is recorded
    under "Backward". CPU activity is always collected; CUDA activity only
    when ``args.device == "cuda"``. Every rank profiles, but only rank 0
    exports a Chrome trace to ``args.trace_file``.

    Args:
        args: parsed CLI namespace; reads ``repeat``, ``device`` and
            ``trace_file``.
        model: callable invoked as ``model(*inputs)``.
        inputs: positional arguments for ``model``.
        rank: this process's rank; only rank 0 writes the trace.
    """
    activities = [ProfilerActivity.CPU]
    if args.device == "cuda":
        # No CUDA work runs for other devices, so only ask for CUDA events
        # when the model actually lives on a GPU.
        activities.append(ProfilerActivity.CUDA)
    # Accept a str as well: an argparse option declared without type=int
    # yields one, and range() would reject it.
    num_iters = int(args.repeat)

    with profile(activities=activities) as prof:
        for _ in range(num_iters):
            with record_function("Forward"):
                outputs = model(*inputs)
                loss = reduce_to_scalar_loss(outputs)
            with record_function("Backward"):
                loss.backward()
    if rank == 0:
        prof.export_chrome_trace(args.trace_file)


def _compile_with_dynamo(args, model):
    """Reset dynamo, apply the dynamo-related flags in ``args``, and wrap ``model``.

    ``args.dynamo`` names the backend passed to ``dynamo.optimize``. The
    special value ``"print"`` selects a local backend instead; it prints
    each captured FX graph and returns the graph module unchanged.
    """
    dynamo.reset()
    if args.verbose:
        dynamo.config.verbose = True
        dynamo.config.log_level = logging.DEBUG
    if args.dynamo_no_optimize_ddp:
        dynamo.config.optimize_ddp = False
    if args.dynamo == "inductor" and args.fsdp:
        # Nothing in this file imports torch._inductor, so reaching it as an
        # attribute of `torch` would depend on some other module having
        # imported it first. Import it explicitly; a `from` import is used so
        # the name `torch` is not rebound as a function local.
        from torch._inductor import config as inductor_config

        inductor_config.triton.cudagraphs = False
        log.warning("disabling inductor cudagraphs for compatibility with FSDP")

    def print_compile(gm, ex):
        print(
            f"print_compile:\n{str(gm.graph)}\n-----------------------------------------"
        )
        return gm

    backend = print_compile if args.dynamo == "print" else args.dynamo
    return dynamo.optimize(backend)(model)


def run_model(args, model, inputs, key):
    """Prepare ``model`` on this rank and time ``args.repeat`` iterations.

    The rank and world size are read from the ``RANK`` and ``WORLD_SIZE``
    environment variables (defaulting to 0 and 1), then ``setup()`` is
    called. The model and every tensor leaf of ``inputs`` move to
    ``f"{args.device}:{rank}"``. The model is optionally wrapped in FSDP or
    DDP and optionally compiled with dynamo. The function then runs a
    3-iteration warmup and ``args.repeat`` timed iterations via ``timed``,
    optionally followed by torchviz and profiler runs. ``cleanup()`` is
    always called once ``setup()`` has succeeded, even if the benchmark
    raises.

    Args:
        args: parsed command-line namespace.
        model: the module to benchmark.
        inputs: pytree of model inputs; tensor leaves are moved to the device.
        key: run identifier; currently unused, kept for caller compatibility.

    Returns:
        The value ``timed`` returns for the timed run (the script's entry
        point divides it by ``args.repeat`` to report mean latency).
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    # Tolerate a str from argparse so timed() always receives an int.
    repeat = int(args.repeat)

    setup(rank, world_size)
    try:
        if args.device == "cuda":
            # needed for FSDP
            torch.cuda.set_device(rank)

        dev_rank = f"{args.device}:{rank}"
        model = model.to(dev_rank)

        def move_tensor(maybe_tensor):
            # Non-tensor leaves pass through untouched.
            if torch.is_tensor(maybe_tensor):
                return maybe_tensor.to(dev_rank)
            return maybe_tensor

        inputs = pytree.tree_map(move_tensor, inputs)

        if args.fsdp:
            model = apply_fsdp(
                args,
                model,
                use_checkpointing=args.fsdp_checkpoint,
                use_wrap_policy=args.fsdp_wrap,
            )
        elif args.ddp:
            model = DDP(model)

        if args.verbose:
            print(model)

        if args.dynamo:
            model = _compile_with_dynamo(args, model)

        # warmup
        timed(model, model_iter_fn, inputs, times=3, return_result=False)
        t_total = timed(
            model, model_iter_fn, inputs, times=repeat, return_result=False
        )
        if args.torchviz:
            torchviz_model(args, model, inputs, rank)
        if args.profile:
            profile_model(args, model, inputs, rank)
    finally:
        # Pair every successful setup() with cleanup(), even on failure.
        cleanup()
    return t_total


if __name__ == "__main__":
    # Command-line entry point: parse options, build the model, run the
    # benchmark on this rank and print the mean per-iteration latency.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda")
    parser.add_argument(
        "--dynamo",
        default=None,
        help="if set to a str, uses dynamo[str] backend. else, eager",
    )
    parser.add_argument("--verbose", action="store_true")
    # type=int so downstream code receives a number, not the raw CLI string.
    parser.add_argument("--batch-size", "--batch_size", type=int, default=None)
    parser.add_argument(
        "--torchviz", action="store_true", help="Dump autograd graph with torchviz"
    )
    parser.add_argument("--profile", action="store_true", help="Run the profiler")
    parser.add_argument(
        "--trace-file",
        "--trace_file",
        default="profile.json",
        help="Path of the Chrome trace written by --profile",
    )
    # type=int: the value feeds range()/timed() and the latency division below,
    # all of which reject a str.
    parser.add_argument(
        "--repeat", type=int, default=10, help="Repeats for timing run"
    )
    parser.add_argument(
        "--dynamo-no-optimize-ddp",
        "--dynamo_no_optimize_ddp",
        action="store_true",
        help="Disable dynamo's ddp optimizer (enabled by default)",
    )
    parser.add_argument(
        "--fsdp-checkpoint",
        "--fsdp_checkpoint",
        action="store_true",
        help="Use gradient checkpointing via model-specific policy",
    )
    parser.add_argument(
        "--fsdp-wrap",
        "--fsdp_wrap",
        action="store_true",
        help="Apply fsdp to submodules via model-specific policy",
    )

    dist_arg = parser.add_mutually_exclusive_group()
    dist_arg.add_argument("--ddp", action="store_true")
    dist_arg.add_argument("--fsdp", action="store_true")

    model_arg = parser.add_mutually_exclusive_group(required=True)
    model_arg.add_argument(
        "--torchbench-model",
        "--torchbench_model",
        help="name of torchbench model, e.g. hf_Bert",
    )
    model_arg.add_argument(
        "--toy-model", "--toy_model", action="store_true", help="use toy model instead"
    )
    args = parser.parse_args()

    model_name = "ToyModel" if args.toy_model else args.torchbench_model
    model, inputs = get_model(args)

    fn = partial(run_model, args, model, inputs)

    # Parsed the same way run_model parses it.
    world_size = int(os.getenv("WORLD_SIZE", 1))
    t_total = fn(f"{model_name}_{world_size}")
    print(f"mean latency {t_total / args.repeat} across {args.repeat} runs")
