#!/usr/bin/env python3

import importlib
import logging
import os
import re
import subprocess
import sys
import warnings


try:
    from .common import BenchmarkRunner, download_retry_decorator, main
except ImportError:
    from common import BenchmarkRunner, download_retry_decorator, main

import torch
from torch._dynamo.testing import collect_results, reduce_to_scalar_loss
from torch._dynamo.utils import clone_inputs


# Enable FX graph caching
if "TORCHINDUCTOR_FX_GRAPH_CACHE" not in os.environ:
    torch._inductor.config.fx_graph_cache = True


def pip_install(package):
    """Install *package* with pip, using the interpreter running this script.

    Raises ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(command)


# Make sure timm (PyTorch Image Models) is importable, installing it from source
# the first time it is found to be missing.
try:
    importlib.import_module("timm")
except ModuleNotFoundError:
    print("Installing PyTorch Image Models...")
    pip_install("git+https://github.com/rwightman/pytorch-image-models")
    # The package appeared on disk after the interpreter started, so the import
    # system's cached directory listings must be refreshed before it can be found.
    importlib.invalidate_caches()

# Deliberately not inside a ``finally`` clause: if the installation (or the probe
# import) fails for some other reason, that error should propagate unchanged
# instead of being masked by a second import error raised from these lines.
from timm import __version__ as timmversion
from timm.data import resolve_data_config
from timm.models import create_model

# Maps each benchmarked timm model name to the batch size recorded for it in
# timm_models_list.txt, which holds one "<model_name> <batch_size>" entry per line.
TIMM_MODELS = {}
filename = os.path.join(os.path.dirname(__file__), "timm_models_list.txt")

with open(filename) as fh:
    for line in fh:
        # split() with no argument copes with tabs / repeated spaces and strips
        # the trailing newline in one go.
        fields = line.split()
        if not fields:
            # Tolerate blank lines (e.g. an empty last line) instead of crashing
            # at import time.
            continue
        # Anything other than exactly two fields is still a hard error.
        model_name, batch_size = fields
        TIMM_MODELS[model_name] = int(batch_size)


# TODO - Figure out the reason for the cold start memory spike

# Models whose recorded batch size is cut in half when TimmRunner.load_model is
# not given an explicit batch size.
_HALVED_BATCH_SIZE_MODELS = (
    "beit_base_patch16_224",
    "convit_base",
    "convmixer_768_32",
    "convnext_base",
    "cspdarknet53",
    "deit_base_distilled_patch16_224",
    "gluon_xception65",
    "mobilevit_s",
    "pnasnet5large",
    "poolformer_m36",
    "resnest101e",
    "swin_base_patch4_window7_224",
    "swsl_resnext101_32x16d",
    "vit_base_patch16_224",
    "volo_d1_224",
)

# model name -> factor the recorded batch size is divided by (result is at least 1).
BATCH_SIZE_DIVISORS = {
    **dict.fromkeys(_HALVED_BATCH_SIZE_MODELS, 2),
    "jx_nest_base": 4,
}

# Training accuracy checks use a tolerance of 4e-2 for these models, unless the
# model is also listed in one of the "even higher" sets below.
REQUIRE_HIGHER_TOLERANCE = {
    "fbnetv3_b",
    "gmixer_24_224",
    "hrnet_w18",
    "inception_v3",
    "mixer_b16_224",
    "mobilenetv3_large_100",
    "sebotnet33ts_256",
    "selecsls42b",
}

# Training accuracy checks use a tolerance of 8e-2 for these models.
REQUIRE_EVEN_HIGHER_TOLERANCE = {
    "beit_base_patch16_224",
    "cspdarknet53",
    "levit_128",
    "sebotnet33ts_256",
}

# Same 8e-2 training tolerance, but only applied when Inductor's max_autotune
# setting is enabled.
REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE = {
    "gluon_inception_v3",
}

# Tolerance is raised to 8e-2 for these models when the freezing option is on.
REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING = {
    "adv_inception_v3",
    "botnet26t_256",
    "gluon_inception_v3",
    "selecsls42b",
    "swsl_resnext101_32x16d",
}

# Models whose scalar training loss is additionally divided by 1000.
SCALED_COMPUTE_LOSS = {
    "ese_vovnet19b_dw",
    "fbnetc_100",
    "mnasnet_100",
    "mobilevit_s",
    "sebotnet33ts_256",
}

# Exposed through TimmRunner.force_amp_for_fp16_bf16_models.
# NOTE(review): presumably the harness runs these under AMP when fp16/bf16 is
# requested - confirm against the BenchmarkRunner implementation.
FORCE_AMP_FOR_FP16_BF16_MODELS = {
    "convit_base",
    "xcit_large_24_p8_224",
}

# Reported by TimmRunner.skip_accuracy_check_as_eager_non_deterministic when both
# the accuracy and training options are set.
SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS = {
    "xcit_large_24_p8_224",
}

# Models for which TimmRunner.use_larger_multiplier_for_smaller_tensor is True.
REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR = {
    "cspdarknet53",
    "inception_v3",
    "mobilenetv3_large_100",
}


def refresh_model_names():
    """Rewrite the timm model list with one representative model per family.

    Model names are gathered from two places: the ``timm.create_model('...')``
    snippets in the markdown docs of a pytorch-image-models checkout expected at
    ``../pytorch-image-models`` (relative to the current working directory), and
    ``timm.models.list_models(pretrained=True)`` with ``*in21k`` variants
    excluded. Models are grouped into families; for a family that appears in the
    docs the representative comes from the docs, otherwise it is the first model
    ``list_models`` returned for that family.

    The sorted names are written, one per line, to
    ``benchmarks/timm_models_list.txt`` when a ``benchmarks`` directory exists in
    the current working directory, and to ``timm_models_list.txt`` otherwise.
    """
    import glob

    from timm.models import list_models

    def read_models_from_docs():
        """Collect the model names used in ``timm.create_model(...)`` doc lines."""
        models = set()
        # TODO - set the path to pytorch-image-models repo
        for fn in glob.glob("../pytorch-image-models/docs/models/*.md"):
            with open(fn) as f:
                for line in f:
                    if line.startswith("model = timm.create_model("):
                        # The model name is the first single-quoted token.
                        models.add(line.split("'")[1])
        return models

    def get_family_name(name):
        """Map a model name to its family (known substring, else name prefix)."""
        known_families = (
            "darknet",
            "densenet",
            "dla",
            "dpn",
            "ecaresnet",
            "halo",
            "regnet",
            "efficientnet",
            "deit",
            "mobilevit",
            "mnasnet",
            "convnext",
            "resnet",
            "resnest",
            "resnext",
            "selecsls",
            "vgg",
            "xception",
        )

        # First match wins, so the order of known_families matters.
        for known_family in known_families:
            if known_family in name:
                return known_family

        if name.startswith("gluon_"):
            return "gluon_" + name.split("_")[1]
        return name.split("_")[0]

    def populate_family(models):
        """Group *models* by family name, keeping the order they arrive in."""
        family = {}
        for model_name in models:
            family.setdefault(get_family_name(model_name), []).append(model_name)
        return family

    docs_models = read_models_from_docs()
    all_models = list_models(pretrained=True, exclude_filters=["*in21k"])

    all_models_family = populate_family(all_models)
    # docs_models is a set; sort it so the representative chosen per family does
    # not depend on set iteration order (which varies between interpreter runs).
    docs_models_family = populate_family(sorted(docs_models))

    # Families already covered by the docs take their representative from there.
    # pop() with a default: a docs family may have no counterpart among the
    # pretrained models, and that must not abort the refresh with a KeyError.
    for key in docs_models_family:
        all_models_family.pop(key, None)

    chosen_models = {members[0] for members in docs_models_family.values()}
    chosen_models.update(members[0] for members in all_models_family.values())

    out_path = "timm_models_list.txt"
    if os.path.exists("benchmarks"):
        out_path = "benchmarks/" + out_path
    # NOTE(review): only model names are written here, while the module-level
    # loader of timm_models_list.txt expects "<model_name> <batch_size>" on each
    # line - batch sizes have to be added before the file can be loaded.
    with open(out_path, "w") as fw:
        fw.writelines(model_name + "\n" for model_name in sorted(chosen_models))


class TimmRunner(BenchmarkRunner):
    """Benchmark runner for the timm models listed in ``TIMM_MODELS``.

    Supplies the timm-specific pieces on top of ``BenchmarkRunner``: model
    download and example-input generation, per-model accuracy tolerances, loss
    computation, and the forward / forward+backward steps.
    """

    def __init__(self):
        super().__init__()
        self.suite_name = "timm_models"

    @property
    def force_amp_for_fp16_bf16_models(self):
        """Model names from ``FORCE_AMP_FOR_FP16_BF16_MODELS``."""
        return FORCE_AMP_FOR_FP16_BF16_MODELS

    @property
    def force_fp16_for_bf16_models(self):
        """Empty set: no timm model is listed here."""
        return set()

    @property
    def get_output_amp_train_process_func(self):
        """Empty mapping: no timm model registers an entry here."""
        return {}

    @property
    def skip_accuracy_check_as_eager_non_deterministic(self):
        """Models to skip in accuracy checks; non-empty only for accuracy+training."""
        if self.args.accuracy and self.args.training:
            return SKIP_ACCURACY_CHECK_AS_EAGER_NON_DETERMINISTIC_MODELS
        return set()

    @property
    def guard_on_nn_module_models(self):
        """Model names reported under this property (meaning defined by the base runner)."""
        return {
            "convit_base",
        }

    @property
    def inline_inbuilt_nn_modules_models(self):
        """Model names reported under this property (meaning defined by the base runner)."""
        return {
            "lcnet_050",
        }

    @staticmethod
    def _release_tuple(version):
        """Return the leading dotted-number part of *version* as a tuple of ints.

        ``"0.9.16"`` gives ``(0, 9, 16)`` and ``"1.0.0.dev0"`` gives ``(1, 0, 0)``.
        Results shorter than three components are zero-padded; a string that
        does not start with a digit gives ``()``.
        """
        match = re.match(r"\d+(?:\.\d+)*", version)
        if match is None:
            return ()
        parts = tuple(int(part) for part in match.group().split("."))
        return parts + (0,) * max(0, 3 - len(parts))

    @download_retry_decorator
    def _download_model(self, model_name):
        """Create the pretrained timm model called *model_name*.

        NOTE(review): ``load_model`` checks the result for ``None``, so the retry
        decorator presumably returns ``None`` when it gives up - confirm.
        """
        model = create_model(
            model_name,
            in_chans=3,
            scriptable=False,
            num_classes=None,
            drop_rate=0.0,
            drop_path_rate=None,
            drop_block_rate=None,
            pretrained=True,
        )
        return model

    def load_model(
        self,
        device,
        model_name,
        batch_size=None,
        extra_args=None,
    ):
        """Download *model_name* and build the example inputs to run it with.

        Args:
            device: Device the model, inputs and targets are placed on.
            model_name: Key into ``TIMM_MODELS``.
            batch_size: Batch size to use. When falsy, the recorded batch size
                (divided per ``BATCH_SIZE_DIVISORS`` where listed) is used.
            extra_args: Unused here.

        Returns:
            ``(device, model_name, model, example_inputs, batch_size)`` where
            ``example_inputs`` is a one-element list holding the input tensor.

        Raises:
            NotImplementedError: If activation checkpointing was requested.
            RuntimeError: If the model could not be downloaded.
        """
        if self.args.enable_activation_checkpointing:
            raise NotImplementedError(
                "Activation checkpointing not implemented for Timm models"
            )

        is_training = self.args.training
        use_eval_mode = self.args.use_eval_mode

        channels_last = self._args.channels_last
        model = self._download_model(model_name)

        if model is None:
            raise RuntimeError(f"Failed to load model '{model_name}'")
        model.to(
            device=device,
            memory_format=torch.channels_last if channels_last else None,
        )

        self.num_classes = model.num_classes

        # From timm 0.8.0 on the args are passed as a plain dict; older releases
        # get the namespace object. Compare numerically: as strings, "0.10.0"
        # would sort below "0.8.0".
        args_as_dict = self._release_tuple(timmversion) >= (0, 8, 0)
        data_config = resolve_data_config(
            vars(self._args) if args_as_dict else self._args,
            model=model,
            use_test_size=not is_training,
        )
        input_size = data_config["input_size"]
        recorded_batch_size = TIMM_MODELS[model_name]

        if model_name in BATCH_SIZE_DIVISORS:
            recorded_batch_size = max(
                recorded_batch_size // BATCH_SIZE_DIVISORS[model_name], 1
            )
        batch_size = batch_size or recorded_batch_size

        # Fixed seed so every run sees the same inputs: random integers in
        # [0, 256) as float32, normalised by their own mean and std.
        torch.manual_seed(1337)
        input_tensor = torch.randint(
            256, size=(batch_size,) + input_size, device=device
        ).to(dtype=torch.float32)
        mean = torch.mean(input_tensor)
        std_dev = torch.std(input_tensor)
        example_inputs = (input_tensor - mean) / std_dev

        if channels_last:
            example_inputs = example_inputs.contiguous(
                memory_format=torch.channels_last
            )
        example_inputs = [
            example_inputs,
        ]
        self.target = self._gen_target(batch_size, device)

        self.loss = torch.nn.CrossEntropyLoss().to(device)

        if model_name in SCALED_COMPUTE_LOSS:
            self.compute_loss = self.scaled_compute_loss
        else:
            # Remove an override installed for a previously loaded model so the
            # scaling does not leak into models that are not supposed to use it.
            vars(self).pop("compute_loss", None)

        if is_training and not use_eval_mode:
            model.train()
        else:
            model.eval()

        self.validate_model(model, example_inputs)

        return device, model_name, model, example_inputs, batch_size

    def iter_model_names(self, args):
        """Yield the model names selected by the index range and name filters.

        A name is yielded when its position in the sorted list lies inside the
        range from ``get_benchmark_indices``, it matches one of ``args.filter``,
        matches none of ``args.exclude``, and appears in neither
        ``args.exclude_exact`` nor ``self.skip_models``. Regex matching is
        case-insensitive.
        """
        # for model_name in list_models(pretrained=True, exclude_filters=["*in21k"]):
        model_names = sorted(TIMM_MODELS.keys())
        start, end = self.get_benchmark_indices(len(model_names))
        # The alternation patterns do not change between iterations.
        include_pattern = "|".join(args.filter)
        exclude_pattern = "|".join(args.exclude)
        for index, model_name in enumerate(model_names):
            if index < start or index >= end:
                continue
            if (
                not re.search(include_pattern, model_name, re.IGNORECASE)
                or re.search(exclude_pattern, model_name, re.IGNORECASE)
                or model_name in args.exclude_exact
                or model_name in self.skip_models
            ):
                continue

            yield model_name

    def pick_grad(self, name, is_training):
        """Return ``torch.enable_grad()`` for training, else ``torch.no_grad()``."""
        return torch.enable_grad() if is_training else torch.no_grad()

    def use_larger_multiplier_for_smaller_tensor(self, name):
        """Tell whether *name* is in ``REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR``."""
        return name in REQUIRE_LARGER_MULTIPLIER_FOR_SMALLER_TENSOR

    def get_tolerance_and_cosine_flag(self, is_training, current_device, name):
        """Return ``(tolerance, cosine)`` for the accuracy check of model *name*.

        The default tolerance is 1e-3, raised to 8e-2 for the models in
        ``REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING`` when freezing is enabled. In
        training the value is always replaced by 8e-2, 4e-2 or 1e-2 depending on
        the tolerance sets the model belongs to. ``current_device`` is unused;
        ``cosine`` is passed through from ``self.args.cosine``.
        """
        cosine = self.args.cosine
        tolerance = 1e-3

        if self.args.freezing and name in REQUIRE_HIGHER_TOLERANCE_FOR_FREEZING:
            # The conv-batchnorm fusion used under freezing may cause a relatively
            # large numerical difference, so a larger tolerance is needed.
            # Check https://github.com/pytorch/pytorch/issues/120545 for context
            tolerance = 8 * 1e-2

        if is_training:
            from torch._inductor import config as inductor_config

            if name in REQUIRE_EVEN_HIGHER_TOLERANCE or (
                inductor_config.max_autotune
                and name in REQUIRE_EVEN_HIGHER_TOLERANCE_MAX_AUTOTUNE
            ):
                tolerance = 8 * 1e-2
            elif name in REQUIRE_HIGHER_TOLERANCE:
                tolerance = 4 * 1e-2
            else:
                tolerance = 1e-2
        return tolerance, cosine

    def _gen_target(self, batch_size, device):
        """Return ``batch_size`` random class indices below ``self.num_classes``."""
        return torch.empty((batch_size,), device=device, dtype=torch.long).random_(
            self.num_classes
        )

    def compute_loss(self, pred):
        """Reduce *pred* to a scalar loss."""
        # High loss values make gradient checking harder, as small changes in
        # accumulation order upset accuracy checks.
        return reduce_to_scalar_loss(pred)

    def scaled_compute_loss(self, pred):
        """Reduce *pred* to a scalar loss and divide it by 1000."""
        # Loss values need to be scaled down further for these models.
        return reduce_to_scalar_loss(pred) / 1000.0

    def forward_pass(self, mod, inputs, collect_outputs=True):
        """Run ``mod(*inputs)`` under autocast and return its output.

        ``collect_outputs`` is not used by this method.
        """
        with self.autocast(**self.autocast_arg):
            return mod(*inputs)

    def forward_and_backward_pass(self, mod, inputs, collect_outputs=True):
        """Run one training step of *mod* on a clone of *inputs*.

        Gradients are zeroed, the forward pass and loss run under autocast, the
        scaled loss is back-propagated and the optimizer is stepped. Returns the
        value of ``collect_results`` when ``collect_outputs`` is true, otherwise
        ``None``.
        """
        cloned_inputs = clone_inputs(inputs)
        self.optimizer_zero_grad(mod)
        with self.autocast(**self.autocast_arg):
            pred = mod(*cloned_inputs)
            if isinstance(pred, tuple):
                # Only the first element of a tuple output feeds the loss.
                pred = pred[0]
            loss = self.compute_loss(pred)
        self.grad_scaler.scale(loss).backward()
        self.optimizer_step()
        if collect_outputs:
            return collect_results(mod, pred, loss, cloned_inputs)
        return None


def timm_main():
    """Silence warnings, set logging to WARNING level and run the timm suite."""
    warnings.filterwarnings("ignore")
    logging.basicConfig(level=logging.WARNING)
    runner = TimmRunner()
    main(runner)


# Run the timm benchmark suite when this file is executed as a script.
if __name__ == "__main__":
    timm_main()
