"""
This script uses linear programming to analyze outputs of triton mm config tuning.
To generate output that can be fed into this script, set the environment variable TORCHINDUCTOR_MM_LOGGING_FILE.

Feeding that file to this script computes the minimized total weighted matmul time as a function of the number of allowed templates.
"""
import json

import click
import pulp


def parse_log_file(file_path):
    """Load the JSON matmul log and split it into invocation counts and timings.

    The file must contain a JSON list. Each element is either an invocation
    record (a dict with an ``"invoke"`` key whose value is the shape that was
    called) or a benchmark record (a dict mapping shapes to lists of timings).

    Returns:
        (occurrence_count, benchmark_logs): ``occurrence_count`` maps each
        invoked shape to how many invocation records named it;
        ``benchmark_logs`` maps each shape to the concatenation of all timing
        lists recorded for it, in file order.
    """
    with open(file_path) as handle:
        entries = json.load(handle)

    invocation_counts = {}
    timings_by_shape = {}

    for record in entries:
        if "invoke" not in record:
            # Benchmark record: merge every shape's timings into its running list.
            for key, shape_timings in record.items():
                timings_by_shape.setdefault(key, []).extend(shape_timings)
            continue
        # Invocation record: one more call of this shape.
        invoked = record["invoke"]
        invocation_counts[invoked] = invocation_counts.get(invoked, 0) + 1

    return invocation_counts, timings_by_shape


def _template_key(timing):
    """Return the identity tuple of the Triton config behind one timing record.

    The tuple is ``(BLOCK_M, BLOCK_N, BLOCK_K, num_stages, num_warps)``; it is
    used both as the template key and inside the pulp variable names.
    """
    return (
        timing["BLOCK_M"],
        timing["BLOCK_N"],
        timing["BLOCK_K"],
        timing["num_stages"],
        timing["num_warps"],
    )


def _best_times(timings):
    """Reduce one shape's timing records to its fastest options in a single pass.

    Returns ``(min_cublas_time, triton_best)``: the smallest ``"cublas"`` time
    (``None`` if there is none) and a dict mapping each Triton template key to
    the smallest ``"triton"`` time recorded for it. Records of any other type
    are ignored.
    """
    min_cublas_time = None
    triton_best = {}
    for timing in timings:
        kind = timing["type"]
        if kind == "cublas":
            if min_cublas_time is None or timing["time"] < min_cublas_time:
                min_cublas_time = timing["time"]
        elif kind == "triton":
            key = _template_key(timing)
            if key not in triton_best or timing["time"] < triton_best[key]:
                triton_best[key] = timing["time"]
    return min_cublas_time, triton_best


def optimize_templates(N, occurrence_count, benchmark_logs, verbose=False):
    """Choose N Triton templates that minimize the total weighted matmul time.

    Builds a mixed-integer linear program with pulp:
      * one binary ``Template_*`` variable per distinct Triton config seen in
        ``benchmark_logs``; exactly ``N`` of them must be 1;
      * one binary ``Select_*`` variable per (shape, option), where an option
        is cuBLAS or a Triton template benchmarked for that shape. Exactly one
        option is selected per shape, and a Triton option may be selected only
        if its template is one of the N chosen;
      * one continuous ``MinTime_*`` variable per shape, constrained to equal
        the selected option's time.
    The objective is ``sum(occurrence_count[shape] * MinTime[shape])``.

    Args:
        N: number of Triton templates that must be selected.
        occurrence_count: shape -> number of times the shape was invoked.
        benchmark_logs: shape -> list of timing dicts with ``"type"``
            (``"cublas"`` or ``"triton"``) and ``"time"``. Triton entries also
            carry BLOCK_M, BLOCK_N, BLOCK_K, num_stages and num_warps.
        verbose: print the model, the solver's variable values and per-shape
            details.

    Returns:
        ``(selected_templates, total_time)``: the list of chosen template
        tuples and the weighted total time of the optimal solution. If the
        solver does not report an optimal solution (for example, because N
        exceeds the number of distinct templates), returns ``([], None)``.

    Raises:
        ValueError: if an invoked shape has neither cuBLAS nor Triton timings.
    """
    # Every distinct Triton config across all benchmarked shapes.
    triton_templates = {
        _template_key(timing)
        for timings in benchmark_logs.values()
        for timing in timings
        if timing["type"] == "triton"
    }

    if verbose:
        print("Occurrence Count:", occurrence_count)
        print("Triton Templates:", triton_templates)

    # Fastest cuBLAS time and fastest time per Triton template, for each shape.
    # Kept per shape so that the debug output below reports the right values.
    best_times = {}
    for shape in occurrence_count:
        min_cublas_time, triton_best = _best_times(benchmark_logs.get(shape, []))
        if min_cublas_time is None and not triton_best:
            raise ValueError(
                f"No cuBLAS or Triton timings recorded for shape {shape!r}"
            )
        best_times[shape] = (min_cublas_time, triton_best)

    template_vars = {
        template: pulp.LpVariable(f"Template_{template}", 0, 1, pulp.LpBinary)
        for template in triton_templates
    }

    min_time_vars = pulp.LpVariable.dicts(
        "MinTime", occurrence_count.keys(), 0, None, pulp.LpContinuous
    )

    prob = pulp.LpProblem("MatrixMultiplicationOptimization", pulp.LpMinimize)

    # Objective: minimize the occurrence-weighted total time.
    prob += pulp.lpSum(
        occurrence_count[shape] * min_time_vars[shape] for shape in occurrence_count
    )

    # Exactly N templates are allowed.
    prob += pulp.lpSum(template_vars.values()) == N

    # (shape, "cublas") or (shape, template) -> binary selection variable.
    selection_vars = {}
    for shape, (min_cublas_time, triton_best) in best_times.items():
        # (time, selection variable) for every option available to this shape.
        options = []
        if min_cublas_time is not None:
            cublas_var = pulp.LpVariable(f"Select_{shape}_cublas", 0, 1, pulp.LpBinary)
            selection_vars[(shape, "cublas")] = cublas_var
            options.append((min_cublas_time, cublas_var))
        for template, triton_time in triton_best.items():
            triton_var = pulp.LpVariable(
                f"Select_{shape}_{template}", 0, 1, pulp.LpBinary
            )
            selection_vars[(shape, template)] = triton_var
            options.append((triton_time, triton_var))
            # A Triton option is usable only if its template is among the N.
            prob += triton_var <= template_vars[template]

        # Exactly one option is chosen for this shape ...
        prob += pulp.lpSum(var for _, var in options) == 1
        # ... and the shape's time is the chosen option's time.
        prob += min_time_vars[shape] == pulp.lpSum(
            time * var for time, var in options
        )

    if verbose:
        print("Constraints:")
        for constraint in prob.constraints.values():
            print(constraint)

    # Solve with solver output suppressed.
    status = prob.solve(pulp.PULP_CBC_CMD(msg=False))
    if status != pulp.LpStatusOptimal:
        # Variable values are meaningless unless the solve was optimal
        # (e.g. the "== N" constraint is infeasible when N is too large).
        if verbose:
            print(f"No optimal solution for N = {N} (status: {pulp.LpStatus[status]})")
        return [], None

    # Binary values come back as floats; threshold instead of comparing == 1.
    selected_templates = [
        template
        for template, var in template_vars.items()
        if (pulp.value(var) or 0) > 0.5
    ]
    total_time = sum(
        pulp.value(min_time_vars[shape]) * occurrence_count[shape]
        for shape in occurrence_count
    )

    if verbose:
        print("Decision Variable Values:")
        for var in prob.variables():
            print(f"{var.name} = {var.varValue}")

        # Per-shape debugging information.
        for shape, (min_cublas_time, triton_best) in best_times.items():
            print(f"Shape: {shape}")
            print(f"  Min Time: {pulp.value(min_time_vars[shape])}")
            print(f"  Occurrences: {occurrence_count[shape]}")
            cublas_var = selection_vars.get((shape, "cublas"))
            cublas_selected = pulp.value(cublas_var) if cublas_var is not None else None
            print(f"  Min CuBLAS Time: {min_cublas_time} Selected: {cublas_selected}")
            for template, triton_time in triton_best.items():
                print(
                    f"  Triton Template: {template} Time: {triton_time} Selected: {pulp.value(selection_vars[(shape, template)])}"
                )

    return selected_templates, total_time


# Main code to parse the log file and optimize templates
@click.command()
@click.argument("filename")
@click.option("--min-templates", default=0, help="Minimum number of templates.")
@click.option("--max-templates", default=10, help="Maximum number of templates.")
@click.option("--verbose", is_flag=True, help="Enable verbose output.")
def main(filename, min_templates, max_templates, verbose):
    # CLI entry point: parse the log once, then solve the template-selection
    # problem for every template budget from min_templates to max_templates
    # (inclusive), printing each result and finally the list of all totals.
    # Intentionally no docstring: click would use it as the --help text.
    counts, logs = parse_log_file(filename)
    weighted_totals = []
    for budget in range(min_templates, max_templates + 1):
        chosen, weighted_total = optimize_templates(budget, counts, logs, verbose)
        print(f"N = {budget}")
        print(f"Selected Templates: {chosen}")
        print(f"Total Weighted Time: {weighted_total}")
        weighted_totals.append(weighted_total)
    print(weighted_totals)


# Invoke the click command only when this file is run as a script.
if __name__ == "__main__":
    main()
