"""
This script will generate default values of quantization configs.
These are for use in the documentation.
"""

import os.path

import torch
from torch.ao.quantization.backend_config import get_native_backend_config_dict
from torch.ao.quantization.backend_config.utils import (
    entry_to_pretty_str,
    remove_boolean_dispatch_from_name,
)


# Directory, next to this script, that receives the generated backend config
# listing. Despite the "IMAGE" in the constant name, the file placed here is
# the plain-text ``default_backend_config.txt`` (see ``output_path`` below).
QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH = os.path.join(
    os.path.realpath(os.path.join(__file__, "..")), "quantization_backend_configs"
)

# Create the directory if it does not exist yet. ``exist_ok=True`` replaces the
# check-then-create sequence (``os.path.exists`` + ``os.mkdir``), which could
# race with another process creating the same directory.
os.makedirs(QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, exist_ok=True)

# Text file that receives the rendered, sorted backend config entries.
output_path = os.path.join(
    QUANTIZATION_BACKEND_CONFIG_IMAGE_PATH, "default_backend_config.txt"
)

with open(output_path, "w") as f:
    native_backend_config_dict = get_native_backend_config_dict()

    configs = native_backend_config_dict["configs"]

    def _sort_key_func(entry):
        pattern = entry["pattern"]
        while isinstance(pattern, tuple):
            pattern = pattern[-1]

        pattern = remove_boolean_dispatch_from_name(pattern)
        if not isinstance(pattern, str):
            # methods are already strings
            pattern = torch.typename(pattern)

        # we want
        #
        #   torch.nn.modules.pooling.AdaptiveAvgPool1d
        #
        # and
        #
        #   torch._VariableFunctionsClass.adaptive_avg_pool1d
        #
        # to be next to each other, so convert to all lower case
        # and remove the underscores, and compare the last part
        # of the string
        pattern_str_normalized = pattern.lower().replace("_", "")
        key = pattern_str_normalized.split(".")[-1]
        return key

    configs.sort(key=_sort_key_func)

    entries = []
    for entry in configs:
        entries.append(entry_to_pretty_str(entry))
    entries = ",\n".join(entries)
    f.write(entries)
