#!/usr/bin/python3
#
# Copyright (C) 2022 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Generate the Java benchmark sources for the 2239-varhandle-perf test.

Usage: generate_java.py <java-dir> <benchmark-id>
"""
# TODO: fix constants when converting the test to a Golem benchmark


from enum import Enum
from pathlib import Path

import io
import sys


class MemLoc(Enum):
    """Kind of memory location that a generated benchmark accesses.

    The member name, camel-cased, becomes part of the generated class name.
    """
    FIELD = 0  # an instance or static field of the benchmark class
    ARRAY = 1  # an element of an array, accessed via an array-element VarHandle
    BYTE_ARRAY_VIEW = 2  # a multi-byte value viewed through a byte[]


def to_camel_case(word):
    """Turn an identifier such as ``LITTLE_ENDIAN`` into ``LittleEndian``.

    Relies on str.title(), so each word's first letter is upper-cased and all
    other letters are lower-cased (``compareAndSet`` becomes
    ``Compareandset``); underscores are then removed.
    """
    titled = word.title()
    return titled.replace('_', '')


class Benchmark:
    """A single generated benchmark: a Java source template plus its parameters.

    gencode() renders the template with str.format; fullname() gives the Java
    class name that the rendered source declares.
    """

    def __init__(self, code, static, vartype, flavour, klass, method, memloc,
                 byteorder="LITTLE_ENDIAN"):
        # The template and the accessor it exercises.
        self.code = code
        self.klass = klass
        self.method = method
        self.flavour = flavour
        # Where the accessed value lives and how it is typed.
        self.memloc = memloc
        self.static = static
        self.vartype = vartype
        self.byteorder = byteorder

    def fullname(self):
        """Return the generated Java class name.

        The name concatenates klass, the camel-cased method, flavour,
        "Static" for static benchmarks, the camel-cased memloc name, byte
        order and vartype, and the suffix "Benchmark", e.g.
        VarHandleGetAcquireStaticFieldLittleEndianIntBenchmark.
        """
        method_part = to_camel_case(self.method)
        static_part = "Static" if self.static else ""
        memloc_part = to_camel_case(self.memloc.name)
        order_part = to_camel_case(self.byteorder)
        type_part = to_camel_case(self.vartype)
        return (f"{self.klass}{method_part}{self.flavour}{static_part}"
                f"{memloc_part}{order_part}{type_part}Benchmark")

    def gencode(self):
        """Return the Java source for this benchmark.

        Supplies every placeholder the templates use. Reflect and Unsafe
        accessors get a type suffix (vartype.title(), e.g. "Int"); a String
        uses no suffix for Reflect and "Object" for Unsafe. A static
        benchmark passes "null" (Reflect) or "this.getClass()" (Unsafe) as the
        receiver argument, and no receiver at all otherwise; a non-static
        benchmark always passes "this".
        """
        # klass -> (suffix used for String, receiver expression for statics).
        # Any other klass (VarHandle) uses neither a suffix nor a receiver.
        typed_accessors = {
            "Reflect": ("", "null"),
            "Unsafe": ("Object", "this.getClass()"),
        }
        if self.klass in typed_accessors:
            string_suffix, receiver_when_static = typed_accessors[self.klass]
            accessor_suffix = (string_suffix if self.vartype == "String"
                               else self.vartype.title())
        else:
            accessor_suffix, receiver_when_static = "", ""

        receiver = "this" if not self.static else receiver_when_static

        name = self.fullname()
        scalar_literals = VALUES[self.vartype]
        array_literals = VALUES["byte[]"][self.byteorder]
        return self.code.format(
            name=name,
            method=self.method + accessor_suffix,
            flavour=self.flavour,
            static_name="Static" if self.static else "",
            static_kwd="static " if self.static else "",
            this=receiver,
            this_comma=receiver + ", " if receiver else "",
            vartype=self.vartype,
            byteorder=self.byteorder,
            value1=scalar_literals[0],
            value2=scalar_literals[1],
            value1_byte_array=array_literals[0],
            value2_byte_array=array_literals[1],
            loop="for (int pass = 0; pass < 100; ++pass)",
            iters=ITERATIONS)


def BenchVHField(code, static, vartype, flavour, method):
    """Build a VarHandle benchmark over a (static or instance) field."""
    return Benchmark(code=code, static=static, vartype=vartype,
                     flavour=flavour, klass="VarHandle", method=method,
                     memloc=MemLoc.FIELD)


def BenchVHArray(code, vartype, flavour, method):
    """Build a VarHandle benchmark over an array element (never static)."""
    return Benchmark(code=code, static=False, vartype=vartype,
                     flavour=flavour, klass="VarHandle", method=method,
                     memloc=MemLoc.ARRAY)


def BenchVHByteArrayView(code, byteorder, vartype, flavour, method):
    """Build a VarHandle byte-array-view benchmark for the given byte order."""
    return Benchmark(code=code, static=False, vartype=vartype,
                     flavour=flavour, klass="VarHandle", method=method,
                     memloc=MemLoc.BYTE_ARRAY_VIEW, byteorder=byteorder)


def BenchReflect(code, static, vartype, method):
    """Build a java.lang.reflect benchmark over a field (no access flavour)."""
    return Benchmark(code=code, static=static, vartype=vartype, flavour="",
                     klass="Reflect", method=method, memloc=MemLoc.FIELD)


def BenchUnsafe(code, static, vartype, method):
    """Build an Unsafe benchmark over a field (no access flavour)."""
    return Benchmark(code=code, static=static, vartype=vartype, flavour="",
                     klass="Unsafe", method=method, memloc=MemLoc.FIELD)


# Java literals substituted into the templates, keyed by Java type.
# For "int", "float" and "String": index 0 fills {value1}, index 1 {value2}.
# "byte[]" maps a java.nio.ByteOrder constant name to the two array
# initialisers used by the byte-array-view templates. Index 0 encodes VALUE
# in that byte order. Index 1 keeps only VALUE's least-significant byte and
# sets the other three bytes to 0xFF ((byte) (-1 >> n)). For VALUE = 42 it
# therefore differs from index 0 until VALUE has been stored into it.
VALUES = {
    "int": ["42", "~42"],
    "float": ["3.14f", "2.17f"],
    "String": ['"qwerty"', "null"],
    "byte[]": {
        "LITTLE_ENDIAN": [
            "{ (byte) VALUE, (byte) (VALUE >> 8),"
            " (byte) (VALUE >> 16), (byte) (VALUE >> 24) }",
            "{ (byte) VALUE, (byte) (-1 >> 8),"
            " (byte) (-1 >> 16), (byte) (-1 >> 24) }",
        ],
        "BIG_ENDIAN": [
            "{ (byte) (VALUE >> 24), (byte) (VALUE >> 16),"
            " (byte) (VALUE >> 8), (byte) VALUE }",
            "{ (byte) (-1 >> 24), (byte) (-1 >> 16),"
            " (byte) (-1 >> 8), (byte) VALUE }",
        ],
    },
}


# TODO: fix these numbers when converting the test to a Golem benchmark
ITERATIONS = 1  # 3000 for a real benchmark
REPEAT = 2  # 30 for a real benchmark
# Templates that alternate two operations per step (the compare-and-set and
# compare-and-exchange ones) unroll each pair REPEAT_HALF times. They emit
# 2 * REPEAT_HALF operations, so keep REPEAT even to match the other templates.
REPEAT_HALF = REPEAT // 2


# Comment line that starts every generated Java file.
BANNER = ("// This file is generated by util-src/generate_java.py"
          " do not directly modify!")


# Imports shared by the VarHandle templates. The leading newline ends the
# BANNER line; the trailing one separates the imports from what follows.
VH_IMPORTS = "\n".join([
    "",
    "import java.lang.invoke.MethodHandles;",
    "import java.lang.invoke.VarHandle;",
    "",
])


# Prologue shared by the VarHandle field templates. It declares class {name}
# extending MicroBenchmark (defined outside this file — presumably in the
# test's Java sources). The class has a `field` member initialised to
# FIELD_VALUE, made static when {static_kwd} is "static ". The constructor
# obtains a VarHandle for `field` via find{static_name}VarHandle.
# The class body is left open: the concrete templates append methods and END
# closes it. Doubled braces become literal braces after str.format.
VH_START = BANNER + VH_IMPORTS + """
class {name} extends MicroBenchmark {{
  static final {vartype} FIELD_VALUE = {value1};
  {static_kwd}{vartype} field = FIELD_VALUE;
  VarHandle vh;

  {name}() throws Throwable {{
    vh = MethodHandles.lookup().find{static_name}VarHandle(this.getClass(), "field", {vartype}.class);
  }}
"""


# Epilogue shared by every template. It closes the `{loop} {{` block and the
# run() method that each template opens. It then adds innerIterations(),
# returning {iters} (ITERATIONS, supplied by Benchmark.gencode), and closes
# the class.
END = """
    }}
  }}

  @Override
  public int innerIterations() {{
      return {iters};
  }}
}}"""


# VarHandle {method}{flavour} read of `field`. setup() checks that the handle
# reads FIELD_VALUE. Each pass of {loop} in run() then performs REPEAT
# unrolled reads into the local `x`, which is never read afterwards.
VH_GET = VH_START + """
  @Override
  public void setup() {{
    {vartype} v = ({vartype}) vh.{method}{flavour}({this});
    if (v != FIELD_VALUE) {{
      throw new RuntimeException("field has unexpected value " + v);
    }}
  }}

  @Override
  public void run() {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}({this});""" * REPEAT + END


# VarHandle {method}{flavour} write of `field`. Each pass of {loop} in run()
# stores FIELD_VALUE REPEAT times (unrolled). teardown() then checks that the
# field still holds FIELD_VALUE.
VH_SET = VH_START + """
  @Override
  public void teardown() {{
    if (field != FIELD_VALUE) {{
      throw new RuntimeException("field has unexpected value " + field);
    }}
  }}

  @Override
  public void run() {{
    {loop} {{""" + """
      vh.{method}{flavour}({this_comma}FIELD_VALUE);""" * REPEAT + END


# VarHandle {method}{flavour} (compareAndSet or weakCompareAndSet) on `field`.
# Each unrolled step first attempts to set the field to {value2} and then
# back to {value1}. The field's current value is read directly and passed as
# the expected value. The pair is repeated REPEAT_HALF times per pass of
# {loop}, and the boolean results are discarded.
VH_CAS = VH_START + """
  @Override
  public void run() {{
    boolean success;
    {loop} {{""" + """
      success = vh.{method}{flavour}({this_comma}field, {value2});
      success = vh.{method}{flavour}({this_comma}field, {value1});""" * REPEAT_HALF + END


# Same pattern as VH_CAS, for compareAndExchange{flavour}. The returned
# witness value is stored in `x` and never read.
VH_CAE = VH_START + """
  @Override
  public void run() {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}({this_comma}field, {value2});
      x = ({vartype}) vh.{method}{flavour}({this_comma}field, {value1});""" * REPEAT_HALF + END


# VarHandle getAndSet{flavour} of {value2} into `field`, REPEAT times per pass
# of {loop}. The previous value is stored in `x`.
VH_GAS = VH_START + """
  @Override
  public void run() {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}({this_comma}{value2});""" * REPEAT + END


# VarHandle getAndAdd{flavour} of {value2} to `field`, REPEAT times per pass
# of {loop}. The template text is identical to VH_GAS; only the {method}
# substituted by the caller differs.
VH_GAA = VH_START + """
  @Override
  public void run() {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}({this_comma}{value2});""" * REPEAT + END


# VarHandle getAndBitwise{Or,Xor,And}{mode} of {value2} on `field`, REPEAT
# times per pass of {loop}. The previous value is kept in a {vartype} local,
# matching the ({vartype}) cast and the sibling templates, so the template is
# not tied to int.
VH_GAB = VH_START + """
  @Override
  public void run() {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}({this_comma}{value2});""" * REPEAT + END


# Prologue for the array-element templates. It declares a one-element
# {vartype}[] `array` holding ELEMENT_VALUE, and a VarHandle obtained from
# MethodHandles.arrayElementVarHandle. The class body is left open for the
# concrete template and END.
VH_START_ARRAY = BANNER + VH_IMPORTS + """
class {name} extends MicroBenchmark {{
  static final {vartype} ELEMENT_VALUE = {value1};
  {vartype}[] array = {{ ELEMENT_VALUE }};
  VarHandle vh;

  {name}() throws Throwable {{
    vh = MethodHandles.arrayElementVarHandle({vartype}[].class);
  }}
"""


# VarHandle {method}{flavour} read of element 0. setup() checks the element
# holds ELEMENT_VALUE. run() reads it REPEAT times per pass of {loop},
# through a local copy `a` of the array reference.
VH_GET_A = VH_START_ARRAY + """
  @Override
  public void setup() {{
    {vartype} v = ({vartype}) vh.{method}{flavour}(array, 0);
    if (v != ELEMENT_VALUE) {{
      throw new RuntimeException("array element has unexpected value: " + v);
    }}
  }}

  @Override
  public void run() {{
    {vartype}[] a = array;
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}(a, 0);""" * REPEAT + END


# VarHandle {method}{flavour} write of element 0 of `array`. Each pass of
# {loop} in run() stores {value2} REPEAT times, through a local copy `a` of
# the array reference. teardown() then checks that element 0 holds {value2}.
VH_SET_A = VH_START_ARRAY + """
  @Override
  public void teardown() {{
    if (array[0] != {value2}) {{
      throw new RuntimeException("array element has unexpected value: " + array[0]);
    }}
  }}

  @Override
  public void run() {{
    {vartype}[] a = array;
    {loop} {{""" + """
      vh.{method}{flavour}(a, 0, {value2});""" * REPEAT + END


# Prologue for the byte-array-view templates. It defines VALUE = {value1} and
# two byte arrays initialised from VALUES["byte[]"][byteorder]: array1 from
# {value1_byte_array} and array2 from {value2_byte_array}. The VarHandle comes
# from MethodHandles.byteArrayViewVarHandle for {vartype}[] in
# ByteOrder.{byteorder}.
VH_START_BYTE_ARRAY_VIEW = BANNER + VH_IMPORTS + """
import java.util.Arrays;
import java.nio.ByteOrder;

class {name} extends MicroBenchmark {{
  static final {vartype} VALUE = {value1};
  byte[] array1 = {value1_byte_array};
  byte[] array2 = {value2_byte_array};
  VarHandle vh;

  {name}() throws Throwable {{
    vh = MethodHandles.byteArrayViewVarHandle({vartype}[].class, ByteOrder.{byteorder});
  }}
"""


# VarHandle {method}{flavour} read of a {vartype} at byte offset 0 of array1.
# setup() checks that the read yields VALUE; run() reads REPEAT times per
# pass of {loop}.
VH_GET_BAV = VH_START_BYTE_ARRAY_VIEW + """
  @Override
  public void setup() {{
    {vartype} v = ({vartype}) vh.{method}{flavour}(array1, 0);
    if (v != VALUE) {{
      throw new RuntimeException("array has unexpected value: " + v);
    }}
  }}

  @Override
  public void run() {{
    byte[] a = array1;
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) vh.{method}{flavour}(a, 0);""" * REPEAT + END


# VarHandle {method}{flavour} write of VALUE at byte offset 0 of array2,
# REPEAT times per pass of {loop}. array2 starts out different from array1
# (see VALUES["byte[]"]). teardown() therefore checks that the stores made
# array2 byte-for-byte equal to array1.
VH_SET_BAV = VH_START_BYTE_ARRAY_VIEW + """
  @Override
  public void teardown() {{
    if (!Arrays.equals(array2, array1)) {{
      throw new RuntimeException("array has unexpected values: " +
          array2[0] + " " + array2[1] + " " + array2[2] + " " + array2[3]);
    }}
  }}

  @Override
  public void run() {{
    byte[] a = array2;
    {loop} {{""" + """
      vh.{method}{flavour}(a, 0, VALUE);""" * REPEAT + END


# Prologue for the java.lang.reflect templates. It declares a `value` member
# (static when {static_kwd} is "static ") with no initialiser, and a Field
# `field` for it looked up with getDeclaredField in the constructor.
REFLECT_START = BANNER + """
import java.lang.reflect.Field;

class {name} extends MicroBenchmark {{
  Field field;
  {static_kwd}{vartype} value;

  {name}() throws Throwable {{
    field = this.getClass().getDeclaredField("value");
  }}
"""


# Reflective read: {method} is "get" plus the type suffix chosen in
# Benchmark.gencode (e.g. getInt, or plain get for String). {this} is "this",
# or "null" for a static field. REPEAT reads are made per pass of {loop}.
REFLECT_GET = REFLECT_START + """
  @Override
  public void run() throws Throwable {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) field.{method}({this});""" * REPEAT + END


# Reflective write of {value1}, REPEAT times per pass of {loop}; {method} is
# e.g. setInt, or plain set for String.
REFLECT_SET = REFLECT_START + """
  @Override
  public void run() throws Throwable {{
    {loop} {{""" + """
      field.{method}({this_comma}{value1});""" * REPEAT + END


# Prologue for the Unsafe templates. The class extends UnsafeMicroBenchmark,
# which is not defined here; presumably it supplies `theUnsafe` and
# get{Static}FieldOffset, both used below. `value` (static when {static_kwd}
# is "static ") starts at {value1`. Its offset is computed once in the
# constructor.
UNSAFE_START = BANNER + """
import java.lang.reflect.Field;
import jdk.internal.misc.Unsafe;

class {name} extends UnsafeMicroBenchmark {{
  long offset;
  {static_kwd}{vartype} value = {value1};

  {name}() throws Throwable {{
    Field field = this.getClass().getDeclaredField("value");
    offset = get{static_name}FieldOffset(field);
  }}
"""


# Unsafe read: {method} is "get" plus the type suffix from Benchmark.gencode
# (e.g. getInt, getObject). The receiver is "this", or "this.getClass()" for
# a static field. REPEAT reads are made per pass of {loop}.
UNSAFE_GET = UNSAFE_START + """
  @Override
  public void run() throws Throwable {{
    {vartype} x;
    {loop} {{""" + """
      x = ({vartype}) theUnsafe.{method}({this_comma}offset);""" * REPEAT + END


# Unsafe write of {value1}, REPEAT times per pass of {loop} (e.g. putInt,
# putObject).
UNSAFE_PUT = UNSAFE_START + """
  @Override
  public void run() throws Throwable {{
    {loop} {{""" + """
      theUnsafe.{method}({this_comma}offset, {value1});""" * REPEAT + END


# Unsafe compare-and-swap/compare-and-set (e.g. compareAndSwapInt). Each
# step flips `value` from {value1} to {value2} and back. There are
# REPEAT_HALF pairs per pass of {loop}, and the results are discarded.
UNSAFE_CAS = UNSAFE_START + """
  @Override
  public void run() throws Throwable {{
    {loop} {{""" + """
      theUnsafe.{method}({this_comma}offset, {value1}, {value2});
      theUnsafe.{method}({this_comma}offset, {value2}, {value1});""" * REPEAT_HALF + END

def benchmark_selector(benchmark_to_run):
    """Return the Benchmark objects that make up one benchmark group.

    benchmark_to_run is compared against the group ids '0' through '15'. A
    value matching none of them selects the Unsafe compareAndSwap /
    compareAndSet group. Within a group, benchmarks are ordered by the
    nesting of the loops below, outermost first.
    """
    statics = (True, False)
    int_and_string = ("int", "String")
    byte_orders = ("BIG_ENDIAN", "LITTLE_ENDIAN")
    bitwise_flavours = [op + mode
                        for op in ["Or", "Xor", "And"]
                        for mode in ["", "Acquire", "Release"]]

    def field_group(template, method, flavours, vartypes=int_and_string):
        # VarHandle field benchmarks: flavour, then static, then vartype.
        return [BenchVHField(template, static, vartype, flavour, method)
                for flavour in flavours
                for static in statics
                for vartype in vartypes]

    def array_group(template, method):
        return [BenchVHArray(template, vartype, "", method)
                for vartype in int_and_string]

    def byte_view_group(template, method):
        return [BenchVHByteArrayView(template, order, "int", "", method)
                for order in byte_orders]

    def plain_field_group(factory, template, method):
        # Reflect / Unsafe benchmarks: static, then vartype.
        return [factory(template, static, vartype, method)
                for static in statics
                for vartype in int_and_string]

    groups = [
        ('0', lambda: field_group(VH_GET, "get",
                                  ["", "Acquire", "Opaque", "Volatile"])),
        ('1', lambda: field_group(VH_SET, "set",
                                  ["", "Volatile", "Opaque", "Release"])),
        ('2', lambda: field_group(VH_CAS, "compareAndSet", [""])),
        ('3', lambda: field_group(VH_CAS, "weakCompareAndSet",
                                  ["", "Plain", "Acquire", "Release"])),
        ('4', lambda: field_group(VH_CAE, "compareAndExchange",
                                  ["", "Acquire", "Release"])),
        ('5', lambda: field_group(VH_GAS, "getAndSet",
                                  ["", "Acquire", "Release"])),
        ('6', lambda: field_group(VH_GAA, "getAndAdd",
                                  ["", "Acquire", "Release"],
                                  ("int", "float"))),
        ('7', lambda: field_group(VH_GAB, "getAndBitwise",
                                  bitwise_flavours, ("int",))),
        ('8', lambda: array_group(VH_GET_A, "get")),
        ('9', lambda: array_group(VH_SET_A, "set")),
        ('10', lambda: byte_view_group(VH_GET_BAV, "get")),
        ('11', lambda: byte_view_group(VH_SET_BAV, "set")),
        ('12', lambda: plain_field_group(BenchReflect, REFLECT_GET, "get")),
        ('13', lambda: plain_field_group(BenchReflect, REFLECT_SET, "set")),
        ('14', lambda: plain_field_group(BenchUnsafe, UNSAFE_GET, "get")),
        ('15', lambda: plain_field_group(BenchUnsafe, UNSAFE_PUT, "put")),
    ]
    for group_id, build in groups:
        if benchmark_to_run == group_id:
            return build()

    # Fallback group for any other id.
    return [BenchUnsafe(UNSAFE_CAS, static, vartype, method)
            for method in ["compareAndSwap", "compareAndSet"]
            for static in statics
            for vartype in int_and_string]

def main(argv):
    """Generate the selected benchmarks and a Main.java that runs them all.

    argv follows the sys.argv layout. argv[1] is an existing directory that
    receives one <fullname>.java per benchmark plus Main.java. argv[2] is the
    group id passed to benchmark_selector(). If an argument is missing or
    argv[1] is not a directory, a message goes to stderr and the process exits
    with status 1.
    """
    if len(argv) < 3:
        prog = argv[0] if argv else "generate_java.py"
        print("usage: {} <java-dir> <benchmark-id>".format(prog),
              file=sys.stderr)
        sys.exit(1)

    final_java_dir = Path(argv[1])
    # is_dir() is also False for paths that do not exist.
    if not final_java_dir.is_dir():
        print("{} is not a valid java dir".format(final_java_dir), file=sys.stderr)
        sys.exit(1)

    all_benchmarks = benchmark_selector(argv[2])

    main_java = BANNER + """
    public class Main {
      static MicroBenchmark[] benchmarks;

      private static void initialize() throws Throwable {
        benchmarks = new MicroBenchmark[] {""" + "".join("""
          new {}(),""".format(b.fullname()) for b in all_benchmarks) + """
        };
      }

      public static void main(String[] args) throws Throwable {
        initialize();
        for (MicroBenchmark benchmark : benchmarks) {
          benchmark.report();
        }
      }
    }"""

    # Explicit encoding so the output does not depend on the build locale.
    for bench in all_benchmarks:
        file_path = final_java_dir / "{}.java".format(bench.fullname())
        with file_path.open("w", encoding="utf-8") as f:
            print(bench.gencode(), file=f)

    with (final_java_dir / "Main.java").open("w", encoding="utf-8") as f:
        print(main_java, file=f)


if __name__ == "__main__":
    # main() reads the output directory and benchmark group id from argv.
    main(sys.argv)
