import csv
from abc import ABC, abstractmethod

from fbscribelogger import make_scribe_logger

import torch._C._instruction_counter as i_counter
import torch._dynamo.config as config
from torch._dynamo.utils import CompileTimeInstructionCounter


scribe_log_torch_benchmark_compile_time = make_scribe_logger(
    "TorchBenchmarkCompileTime",
    """
struct TorchBenchmarkCompileTimeLogEntry {

  # The commit SHA that triggered the workflow, e.g., 02a6b1d30f338206a71d0b75bfa09d85fac0028a. Derived from GITHUB_SHA.
  4: optional string commit_sha;

  # The unix timestamp in seconds for the Scuba Time Column override
  6: optional i64 time;
  7: optional i64 instruction_count; # Instruction count of compilation step
  8: optional string name; # Benchmark name

  # Commit date (not author date) of the commit in commit_sha as a timestamp, e.g., 1724208105. Increasing if the merge bot is used, though not monotonic; duplicates occur when a stack is landed.
  16: optional i64 commit_date;

  # A unique number for each workflow run within a repository, e.g., 19471190684. Derived from GITHUB_RUN_ID.
  17: optional string github_run_id;

  # A unique number for each attempt of a particular workflow run in a repository, e.g., 1. Derived from GITHUB_RUN_ATTEMPT.
  18: optional string github_run_attempt;

  # Indicates if branch protections or rulesets are configured for the ref that triggered the workflow run. Derived from GITHUB_REF_PROTECTED.
  20: optional bool github_ref_protected;

  # The fully-formed ref of the branch or tag that triggered the workflow run, e.g., refs/pull/133891/merge or refs/heads/main. Derived from GITHUB_REF.
  21: optional string github_ref;

  # The weight of the record according to current sampling rate
  25: optional i64 weight;

  # The name of the current job. Derived from JOB_NAME, e.g., linux-jammy-py3.8-gcc11 / test (default, 3, 4, amz2023.linux.2xlarge).
  26: optional string github_job;

  # The GitHub user who triggered the job.  Derived from GITHUB_TRIGGERING_ACTOR.
  27: optional string github_triggering_actor;

  # A unique number for each run of a particular workflow in a repository, e.g., 238742. Derived from GITHUB_RUN_NUMBER.
  28: optional string github_run_number_str;
}
""",  # noqa: B950
)
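
# Illustration (an assumption, not part of the original module): the logger defined
# above is invoked with keyword arguments matching the optional fields of the Thrift
# schema, e.g.
#     scribe_log_torch_benchmark_compile_time(name="my_benchmark", instruction_count=1234)
# The GitHub metadata fields are presumably filled in from the environment variables
# referenced in the schema comments ("Derived from GITHUB_SHA", etc.).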


class BenchmarkBase(ABC):
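    """
    Base class for instruction-count benchmarks.

    Subclasses implement _prepare() (per-iteration setup) and _work() (the code being
    measured), and may override _prepare_once() for one-time setup, name(), and
    description().

    Illustrative sketch of a hypothetical subclass (the benchmark body below is an
    assumption for demonstration, not part of this module; it assumes `import torch`
    and `import torch._dynamo`):

        class MyBenchmark(BenchmarkBase):
            def name(self):
                return "my_benchmark"

            def _prepare(self):
                torch._dynamo.reset()

            def _work(self):
                torch.compile(lambda x: x + 1)(torch.rand(10))

        MyBenchmark().enable_compile_time_instruction_count().collect_all().append_results(
            "results.csv"
        )
    """
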
    # Measure the total number of instructions spent in _work.
    _enable_instruction_count = False

    # Measure the total number of instructions spent in convert_frame.compile_inner.
    # TODO: are there other parts we need to add?
    _enable_compile_time_instruction_count = False

    def enable_instruction_count(self):
        self._enable_instruction_count = True
        return self

    def enable_compile_time_instruction_count(self):
        self._enable_compile_time_instruction_count = True
        return self

    def name(self):
        return ""

    def description(self):
        return ""

    @abstractmethod
    def _prepare(self):
        # Per-iteration setup; called before every measured run of _work().
        pass

    @abstractmethod
    def _work(self):
        # The workload whose instruction count is measured.
        pass

    def _prepare_once(self):  # noqa: B027
        # Optional one-time setup; called once at the start of collect_all().
        pass

    def _count_instructions(self):
        print(f"collecting instruction count for {self.name()}")
        results = []
        for i in range(10):
            self._prepare()
            counter_id = i_counter.start()
            self._work()
            count = i_counter.end(counter_id)
            print(f"instruction count for iteration {i} is {count}")
            results.append(count)
        return min(results)

    def _count_compile_time_instructions(self):
        print(f"collecting compile time instruction count for {self.name()}")
        config.record_compile_time_instruction_count = True

        results = []
        try:
            for i in range(10):
                self._prepare()
                # CompileTimeInstructionCounter.record is only called in
                # convert_frame._compile_inner, so this only counts instructions
                # spent inside compile_inner.
                CompileTimeInstructionCounter.clear()
                self._work()
                count = CompileTimeInstructionCounter.value()
                if count == 0:
                    raise RuntimeError(
                        "compile time instruction count is 0, please check your benchmarks"
                    )
                print(f"compile time instruction count for iteration {i} is {count}")
                results.append(count)
        finally:
            # Reset the flag even if a benchmark iteration raises.
            config.record_compile_time_instruction_count = False
        return min(results)

    def append_results(self, path):
        with open(path, "a", newline="") as csvfile:
            # Create a writer object
            writer = csv.writer(csvfile)
            # Write the data to the CSV file
            for entry in self.results:
                writer.writerow(entry)

    def print(self):
        for entry in self.results:
            print(f"{entry[0]},{entry[1]},{entry[2]}")

    def collect_all(self):
        self._prepare_once()
        self.results = []
        if (
            self._enable_instruction_count
            and self._enable_compile_time_instruction_count
        ):
            raise RuntimeError(
                "not supported until we update the logger; both counts currently log to the same field"
            )

        if self._enable_instruction_count:
            r = self._count_instructions()
            self.results.append((self.name(), "instruction_count", r))
            scribe_log_torch_benchmark_compile_time(
                name=self.name(),
                instruction_count=r,
            )
        if self._enable_compile_time_instruction_count:
            r = self._count_compile_time_instructions()

            self.results.append(
                (
                    self.name(),
                    "compile_time_instruction_count",
                    r,
                )
            )
            # TODO add a new field compile_time_instruction_count to the logger.
            scribe_log_torch_benchmark_compile_time(
                name=self.name(),
                instruction_count=r,
            )
        return self
