// Copyright 2021 Code Intelligence GmbH
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.code_intelligence.jazzer.runtime;

import com.code_intelligence.jazzer.api.HookType;
import com.code_intelligence.jazzer.api.MethodHook;
import java.lang.invoke.MethodHandle;
import java.util.*;

@SuppressWarnings("unused")
final public class TraceCmpHooks {
  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte", targetMethod = "compare",
      targetMethodDescriptor = "(BB)I")
  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte",
      targetMethod = "compareUnsigned", targetMethodDescriptor = "(BB)I")
  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short", targetMethod = "compare",
      targetMethodDescriptor = "(SS)I")
  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short",
      targetMethod = "compareUnsigned", targetMethodDescriptor = "(SS)I")
  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer",
      targetMethod = "compare", targetMethodDescriptor = "(II)I")
  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer",
      targetMethod = "compareUnsigned", targetMethodDescriptor = "(II)I")
  @MethodHook(type = HookType.BEFORE, targetClassName = "kotlin.jvm.internal.Intrinsics ",
      targetMethod = "compare", targetMethodDescriptor = "(II)I")
  public static void
  integerCompare(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId) {
    TraceDataFlowNativeCallbacks.traceCmpInt((int) arguments[0], (int) arguments[1], hookId);
  }

  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Byte",
      targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Byte;)I")
  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Short",
      targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Short;)I")
  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Integer",
      targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Integer;)I")
  public static void
  integerCompareTo(MethodHandle method, Object thisObject, Object[] arguments, int hookId) {
    TraceDataFlowNativeCallbacks.traceCmpInt((int) thisObject, (int) arguments[0], hookId);
  }

  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long", targetMethod = "compare",
      targetMethodDescriptor = "(JJ)I")
  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long",
      targetMethod = "compareUnsigned", targetMethodDescriptor = "(JJ)I")
  public static void
  longCompare(MethodHandle method, Object thisObject, Object[] arguments, int hookId) {
    TraceDataFlowNativeCallbacks.traceCmpLong((long) arguments[0], (long) arguments[1], hookId);
  }

  @MethodHook(type = HookType.BEFORE, targetClassName = "java.lang.Long",
      targetMethod = "compareTo", targetMethodDescriptor = "(Ljava/lang/Long;)I")
  public static void
  longCompareTo(MethodHandle method, Long thisObject, Object[] arguments, int hookId) {
    TraceDataFlowNativeCallbacks.traceCmpLong(thisObject, (long) arguments[0], hookId);
  }

  @MethodHook(type = HookType.BEFORE, targetClassName = "kotlin.jvm.internal.Intrinsics ",
      targetMethod = "compare", targetMethodDescriptor = "(JJ)I")
  public static void
  longCompareKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId) {
    TraceDataFlowNativeCallbacks.traceCmpLong((long) arguments[0], (long) arguments[1], hookId);
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "equals")
  @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
      targetMethod = "equalsIgnoreCase")
  public static void
  equals(MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean areEqual) {
    if (!areEqual && arguments.length == 1 && arguments[0] instanceof String) {
      // The precise value of the result of the comparison is not used by libFuzzer as long as it is
      // non-zero.
      TraceDataFlowNativeCallbacks.traceStrcmp(thisObject, (String) arguments[0], 1, hookId);
    }
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.Object", targetMethod = "equals")
  @MethodHook(
      type = HookType.AFTER, targetClassName = "java.lang.CharSequence", targetMethod = "equals")
  @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.Number", targetMethod = "equals")
  public static void
  genericEquals(
      MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean areEqual) {
    if (!areEqual && arguments.length == 1 && arguments[0] != null
        && thisObject.getClass() == arguments[0].getClass()) {
      TraceDataFlowNativeCallbacks.traceGenericCmp(thisObject, arguments[0], hookId);
    }
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "clojure.lang.Util", targetMethod = "equiv")
  public static void genericStaticEquals(
      MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean areEqual) {
    if (!areEqual && arguments.length == 2 && arguments[0] != null && arguments[1] != null
        && arguments[1].getClass() == arguments[0].getClass()) {
      TraceDataFlowNativeCallbacks.traceGenericCmp(arguments[0], arguments[1], hookId);
    }
  }

  @MethodHook(
      type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "compareTo")
  @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
      targetMethod = "compareToIgnoreCase")
  public static void
  compareTo(
      MethodHandle method, String thisObject, Object[] arguments, int hookId, Integer returnValue) {
    if (returnValue != 0 && arguments.length == 1 && arguments[0] instanceof String) {
      TraceDataFlowNativeCallbacks.traceStrcmp(
          thisObject, (String) arguments[0], returnValue, hookId);
    }
  }

  @MethodHook(
      type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "contentEquals")
  public static void
  contentEquals(MethodHandle method, String thisObject, Object[] arguments, int hookId,
      Boolean areEqualContents) {
    if (!areEqualContents && arguments.length == 1 && arguments[0] instanceof CharSequence) {
      TraceDataFlowNativeCallbacks.traceStrcmp(
          thisObject, ((CharSequence) arguments[0]).toString(), 1, hookId);
    }
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
      targetMethod = "regionMatches", targetMethodDescriptor = "(ZILjava/lang/String;II)Z")
  public static void
  regionsMatches5(
      MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
    if (!returnValue) {
      int toffset = (int) arguments[1];
      String other = (String) arguments[2];
      int ooffset = (int) arguments[3];
      int len = (int) arguments[4];
      regionMatchesInternal((String) thisObject, toffset, other, ooffset, len, hookId);
    }
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String",
      targetMethod = "regionMatches", targetMethodDescriptor = "(ILjava/lang/String;II)Z")
  public static void
  regionMatches4(
      MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
    if (!returnValue) {
      int toffset = (int) arguments[0];
      String other = (String) arguments[1];
      int ooffset = (int) arguments[2];
      int len = (int) arguments[3];
      regionMatchesInternal((String) thisObject, toffset, other, ooffset, len, hookId);
    }
  }

  private static void regionMatchesInternal(
      String thisString, int toffset, String other, int ooffset, int len, int hookId) {
    if (toffset < 0 || ooffset < 0)
      return;
    int cappedThisStringEnd = Math.min(toffset + len, thisString.length());
    int cappedOtherStringEnd = Math.min(ooffset + len, other.length());
    String thisPart = thisString.substring(toffset, cappedThisStringEnd);
    String otherPart = other.substring(ooffset, cappedOtherStringEnd);
    TraceDataFlowNativeCallbacks.traceStrcmp(thisPart, otherPart, 1, hookId);
  }

  @MethodHook(
      type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "contains")
  public static void
  contains(
      MethodHandle method, String thisObject, Object[] arguments, int hookId, Boolean doesContain) {
    if (!doesContain && arguments.length == 1 && arguments[0] instanceof CharSequence) {
      TraceDataFlowNativeCallbacks.traceStrstr(
          thisObject, ((CharSequence) arguments[0]).toString(), hookId);
    }
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "indexOf")
  @MethodHook(
      type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "lastIndexOf")
  @MethodHook(
      type = HookType.AFTER, targetClassName = "java.lang.StringBuffer", targetMethod = "indexOf")
  @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.StringBuffer",
      targetMethod = "lastIndexOf")
  @MethodHook(
      type = HookType.AFTER, targetClassName = "java.lang.StringBuilder", targetMethod = "indexOf")
  @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.StringBuilder",
      targetMethod = "lastIndexOf")
  public static void
  indexOf(
      MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) {
    if (returnValue == -1 && arguments.length >= 1 && arguments[0] instanceof String) {
      TraceDataFlowNativeCallbacks.traceStrstr(
          thisObject.toString(), (String) arguments[0], hookId);
    }
  }

  @MethodHook(
      type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "startsWith")
  @MethodHook(
      type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "endsWith")
  public static void
  startsWith(MethodHandle method, String thisObject, Object[] arguments, int hookId,
      Boolean doesStartOrEndsWith) {
    if (!doesStartOrEndsWith && arguments.length >= 1 && arguments[0] instanceof String) {
      TraceDataFlowNativeCallbacks.traceStrstr(thisObject, (String) arguments[0], hookId);
    }
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "java.lang.String", targetMethod = "replace",
      targetMethodDescriptor =
          "(Ljava/lang/CharSequence;Ljava/lang/CharSequence;)Ljava/lang/String;")
  public static void
  replace(
      MethodHandle method, Object thisObject, Object[] arguments, int hookId, String returnValue) {
    String original = (String) thisObject;
    // Report only if the replacement was not successful.
    if (original.equals(returnValue)) {
      String target = arguments[0].toString();
      TraceDataFlowNativeCallbacks.traceStrstr(original, target, hookId);
    }
  }

  // For standard Kotlin packages, which are named according to the pattern kotlin.*, we append a
  // whitespace to the package name of the target class so that they are not mangled due to shading.
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.jvm.internal.Intrinsics ",
      targetMethod = "areEqual")
  @MethodHook(
      type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "equals")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "equals$default")
  public static void
  equalsKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
      Boolean equalStrings) {
    if (!equalStrings && arguments.length >= 2 && arguments[0] instanceof String
        && arguments[1] instanceof String) {
      TraceDataFlowNativeCallbacks.traceStrcmp(
          (String) arguments[0], (String) arguments[1], 1, hookId);
    }
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "contentEquals")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "contentEquals$default")
  public static void
  contentEqualKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
      Boolean equalStrings) {
    if (!equalStrings && arguments.length >= 2 && arguments[0] instanceof CharSequence
        && arguments[1] instanceof CharSequence) {
      TraceDataFlowNativeCallbacks.traceStrcmp(
          arguments[0].toString(), arguments[1].toString(), 1, hookId);
    }
  }

  @MethodHook(
      type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "compareTo")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "compareTo$default")
  public static void
  compareToKt(
      MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) {
    if (returnValue != 0 && arguments.length >= 2 && arguments[0] instanceof String
        && arguments[1] instanceof String) {
      TraceDataFlowNativeCallbacks.traceStrcmp(
          (String) arguments[0], (String) arguments[1], 1, hookId);
    }
  }

  @MethodHook(
      type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "endsWith")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "endsWith$default")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "startsWith")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "startsWith$default")
  public static void
  startsWithKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
      Boolean doesStartOrEndsWith) {
    if (!doesStartOrEndsWith && arguments.length >= 2 && arguments[0] instanceof CharSequence
        && arguments[1] instanceof CharSequence) {
      TraceDataFlowNativeCallbacks.traceStrstr(
          arguments[0].toString(), arguments[1].toString(), hookId);
    }
  }

  @MethodHook(
      type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "contains")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "contains$default")
  public static void
  containsKt(
      MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Boolean doesContain) {
    if (!doesContain && arguments.length >= 2 && arguments[0] instanceof CharSequence
        && arguments[1] instanceof CharSequence) {
      TraceDataFlowNativeCallbacks.traceStrstr(
          arguments[0].toString(), arguments[1].toString(), hookId);
    }
  }

  @MethodHook(
      type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "indexOf")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "indexOf$default")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "lastIndexOf")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "lastIndexOf$default")
  public static void
  indexOfKt(
      MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) {
    if (returnValue != -1 || arguments.length < 2 || !(arguments[0] instanceof CharSequence)) {
      return;
    }
    if (arguments[1] instanceof String) {
      TraceDataFlowNativeCallbacks.traceStrstr(
          arguments[0].toString(), (String) arguments[1], hookId);
    } else if (arguments[1] instanceof Character) {
      TraceDataFlowNativeCallbacks.traceStrstr(
          arguments[0].toString(), ((Character) arguments[1]).toString(), hookId);
    }
  }

  @MethodHook(
      type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "replace")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "replace$default")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "replaceAfter")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "replaceAfter$default")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "replaceAfterLast")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "replaceAfterLast$default")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "replaceBefore")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "replaceBefore$default")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "replaceBeforeLast")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "replaceBeforeLast$default")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "replaceFirst")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "replaceFirst$default")
  public static void
  replaceKt(
      MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, String returnValue) {
    if (arguments.length < 2 || !(arguments[0] instanceof String)) {
      return;
    }
    String original = (String) arguments[0];
    if (!original.equals(returnValue)) {
      return;
    }

    // We currently don't handle the overloads that take a regex as a second argument.
    if (arguments[1] instanceof String || arguments[1] instanceof Character) {
      TraceDataFlowNativeCallbacks.traceStrstr(original, arguments[1].toString(), hookId);
    }
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "regionMatches",
      targetMethodDescriptor = "(Ljava/lang/String;ILjava/lang/String;IIZ)Z")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "regionMatches$default",
      targetMethodDescriptor = "(Ljava/lang/String;ILjava/lang/String;IIZILjava/lang/Object;)Z")
  public static void
  regionMatchesKt(MethodHandle method, Object alwaysNull, Object[] arguments, int hookId,
      Boolean doesRegionMatch) {
    if (!doesRegionMatch) {
      String thisString = arguments[0].toString();
      int thisOffset = (int) arguments[1];
      String other = arguments[2].toString();
      int otherOffset = (int) arguments[3];
      int length = (int) arguments[4];
      regionMatchesInternal(thisString, thisOffset, other, otherOffset, length, hookId);
    }
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "indexOfAny")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "indexOfAny$default")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "lastIndexOfAny")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "lastIndexOfAny$default")
  public static void
  indexOfAnyKt(
      MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Integer returnValue) {
    if (returnValue == -1 && arguments.length >= 2 && arguments[0] instanceof CharSequence) {
      guideTowardContainmentOfFirstElement(arguments[0].toString(), arguments[1], hookId);
    }
  }

  @MethodHook(
      type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ", targetMethod = "findAnyOf")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "findAnyOf$default")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "findLastAnyOf")
  @MethodHook(type = HookType.AFTER, targetClassName = "kotlin.text.StringsKt ",
      targetMethod = "findLastAnyOf$default")
  public static void
  findAnyKt(
      MethodHandle method, Object alwaysNull, Object[] arguments, int hookId, Object returnValue) {
    if (returnValue == null && arguments.length >= 2 && arguments[0] instanceof CharSequence) {
      guideTowardContainmentOfFirstElement(arguments[0].toString(), arguments[1], hookId);
    }
  }

  private static void guideTowardContainmentOfFirstElement(
      String containingString, Object candidateCollectionObj, int hookId) {
    if (candidateCollectionObj instanceof Collection<?>) {
      Collection<?> strings = (Collection<?>) candidateCollectionObj;
      if (strings.isEmpty()) {
        return;
      }
      Object firstElementObj = strings.iterator().next();
      if (firstElementObj instanceof CharSequence) {
        TraceDataFlowNativeCallbacks.traceStrstr(
            containingString, firstElementObj.toString(), hookId);
      }
    } else if (candidateCollectionObj.getClass().isArray()) {
      if (candidateCollectionObj.getClass().getComponentType() == char.class) {
        char[] chars = (char[]) candidateCollectionObj;
        if (chars.length > 0) {
          TraceDataFlowNativeCallbacks.traceStrstr(
              containingString, Character.toString(chars[0]), hookId);
        }
      }
    }
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "equals",
      targetMethodDescriptor = "([B[B)Z")
  public static void
  arraysEquals(
      MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
    if (returnValue)
      return;
    byte[] first = (byte[]) arguments[0];
    byte[] second = (byte[]) arguments[1];
    TraceDataFlowNativeCallbacks.traceMemcmp(first, second, 1, hookId);
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "equals",
      targetMethodDescriptor = "([BII[BII)Z")
  public static void
  arraysEqualsRange(
      MethodHandle method, Object thisObject, Object[] arguments, int hookId, Boolean returnValue) {
    if (returnValue)
      return;
    byte[] first =
        Arrays.copyOfRange((byte[]) arguments[0], (int) arguments[1], (int) arguments[2]);
    byte[] second =
        Arrays.copyOfRange((byte[]) arguments[3], (int) arguments[4], (int) arguments[5]);
    TraceDataFlowNativeCallbacks.traceMemcmp(first, second, 1, hookId);
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "compare",
      targetMethodDescriptor = "([B[B)I")
  @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays",
      targetMethod = "compareUnsigned", targetMethodDescriptor = "([B[B)I")
  public static void
  arraysCompare(
      MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) {
    if (returnValue == 0)
      return;
    byte[] first = (byte[]) arguments[0];
    byte[] second = (byte[]) arguments[1];
    TraceDataFlowNativeCallbacks.traceMemcmp(first, second, returnValue, hookId);
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays", targetMethod = "compare",
      targetMethodDescriptor = "([BII[BII)I")
  @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Arrays",
      targetMethod = "compareUnsigned", targetMethodDescriptor = "([BII[BII)I")
  public static void
  arraysCompareRange(
      MethodHandle method, Object thisObject, Object[] arguments, int hookId, Integer returnValue) {
    if (returnValue == 0)
      return;
    byte[] first =
        Arrays.copyOfRange((byte[]) arguments[0], (int) arguments[1], (int) arguments[2]);
    byte[] second =
        Arrays.copyOfRange((byte[]) arguments[3], (int) arguments[4], (int) arguments[5]);
    TraceDataFlowNativeCallbacks.traceMemcmp(first, second, returnValue, hookId);
  }

  // The maximal number of elements of a non-TreeMap Map that will be sorted and searched for the
  // key closest to the current lookup key in the mapGet hook.
  private static final int MAX_NUM_KEYS_TO_ENUMERATE = 100;

  @SuppressWarnings({"rawtypes", "unchecked"})
  @MethodHook(type = HookType.AFTER, targetClassName = "java.util.Map", targetMethod = "get")
  public static void mapGet(
      MethodHandle method, Object thisObject, Object[] arguments, int hookId, Object returnValue) {
    if (returnValue != null)
      return;
    if (arguments.length != 1) {
      return;
    }
    if (thisObject == null)
      return;
    final Map map = (Map) thisObject;
    if (map.size() == 0)
      return;
    final Object currentKey = arguments[0];
    if (currentKey == null)
      return;
    // Find two valid map keys that bracket currentKey.
    // This is a generalization of libFuzzer's __sanitizer_cov_trace_switch:
    // https://github.com/llvm/llvm-project/blob/318942de229beb3b2587df09e776a50327b5cef0/compiler-rt/lib/fuzzer/FuzzerTracePC.cpp#L564
    Object lowerBoundKey = null;
    Object upperBoundKey = null;
    try {
      if (map instanceof TreeMap) {
        final TreeMap treeMap = (TreeMap) map;
        try {
          lowerBoundKey = treeMap.floorKey(currentKey);
          upperBoundKey = treeMap.ceilingKey(currentKey);
        } catch (ClassCastException ignored) {
          // Can be thrown by floorKey and ceilingKey if currentKey is of a type that can't be
          // compared to the maps keys.
        }
      } else if (currentKey instanceof Comparable) {
        final Comparable comparableCurrentKey = (Comparable) currentKey;
        // Find two keys that bracket currentKey.
        // Note: This is not deterministic if map.size() > MAX_NUM_KEYS_TO_ENUMERATE.
        int enumeratedKeys = 0;
        for (Object validKey : map.keySet()) {
          if (!(validKey instanceof Comparable))
            continue;
          final Comparable comparableValidKey = (Comparable) validKey;
          // If the key sorts lower than the non-existing key, but higher than the current lower
          // bound, update the lower bound and vice versa for the upper bound.
          try {
            if (comparableValidKey.compareTo(comparableCurrentKey) < 0
                && (lowerBoundKey == null || comparableValidKey.compareTo(lowerBoundKey) > 0)) {
              lowerBoundKey = validKey;
            }
            if (comparableValidKey.compareTo(comparableCurrentKey) > 0
                && (upperBoundKey == null || comparableValidKey.compareTo(upperBoundKey) < 0)) {
              upperBoundKey = validKey;
            }
          } catch (ClassCastException ignored) {
            // Can be thrown by floorKey and ceilingKey if currentKey is of a type that can't be
            // compared to the maps keys.
          }
          if (enumeratedKeys++ > MAX_NUM_KEYS_TO_ENUMERATE)
            break;
        }
      }
    } catch (ConcurrentModificationException ignored) {
      // map was modified by another thread, skip this invocation
      return;
    }
    // Modify the hook ID so that compares against distinct valid keys are traced separately.
    if (lowerBoundKey != null) {
      TraceDataFlowNativeCallbacks.traceGenericCmp(
          currentKey, lowerBoundKey, hookId + lowerBoundKey.hashCode());
    }
    if (upperBoundKey != null) {
      TraceDataFlowNativeCallbacks.traceGenericCmp(
          currentKey, upperBoundKey, hookId + upperBoundKey.hashCode());
    }
  }

  @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions",
      targetMethod = "assertNotEquals",
      targetMethodDescriptor = "(Ljava/lang/Object;Ljava/lang/Object;)V")
  @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions",
      targetMethod = "assertNotEquals",
      targetMethodDescriptor = "(Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/String;)V")
  @MethodHook(type = HookType.AFTER, targetClassName = "org.junit.jupiter.api.Assertions",
      targetMethod = "assertNotEquals",
      targetMethodDescriptor =
          "(Ljava/lang/Object;Ljava/lang/Object;Ljava/util/function/Supplier;)V")
  public static void
  assertEquals(MethodHandle method, Object node, Object[] args, int hookId, Object alwaysNull) {
    if (args[0] != null && args[1] != null && args[0].getClass() == args[1].getClass()) {
      TraceDataFlowNativeCallbacks.traceGenericCmp(args[0], args[1], hookId);
    }
  }
}
