/*
 * Copyright (C) 2023 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "gtest/gtest.h"

#include "berberis/base/bit_util.h"
#include "berberis/guest_abi/function_wrappers.h"
#include "berberis/guest_state/guest_addr.h"
#include "berberis/guest_state/guest_state.h"
#include "berberis/test_utils/guest_exec_region.h"
#include "berberis/test_utils/translation_test.h"

namespace berberis {

namespace {

class HostFunctionWrapperTest : public TranslationTest {
 protected:
  // Note: we are using both guest function wrapper and host function wrapper here.
  //
  // We don't call wrapped function via GuestCall — this would ensure that they match but this could
  // easily happen while they both would violate the calling convention.
  //
  // Instead we use simple functions (usually with just an address of host-wrapped function as a
  // sole argument) and use compiler-generated code which verifies our wrappers support code
  // generated by compiler.
  //
  // This is also important because certain facilities are not supported by GuestCall currently
  // (e.g. large structures like with mallinfo(3) or floating point arguments like certain Vulkan
  // functions).
  template <typename GuestResultType,
            GuestAbi::CallingConventionsVariant kCallingConventionsVariant = GuestAbi::kDefaultAbi,
            typename Func,
            typename... AdditionalParams>
  GuestResultType CallWrappedHostFunctionFromWrappedGuestFunction(GuestAddr guest_function,
                                                                  Func host_function,
                                                                  AdditionalParams... params) {
    WrapHostFunction<kCallingConventionsVariant>(host_function, "HostFunction");

    auto caller = WrapGuestFunction(
        bit_cast<GuestType<GuestResultType (*)(Func, AdditionalParams...)>>(guest_function),
        "GuestFunction");

    return caller(host_function, params...);
  }
};

// Host function that the Unwrap test never registers with a wrapper.
void do_not_wrap_me() {
  // Intentionally empty: the test only uses this function's address.
}

// Host function that the Unwrap test registers via WrapHostFunctionImpl and then unwraps.
void wrap_me() {
  // Intentionally empty: the test only uses this function's address.
}

// No-op trampoline passed to WrapHostFunctionImpl by the Unwrap test, which only checks unwrapping.
void trampoline(HostCode /* host_code */, ProcessState* /* process_state */) {
  // Intentionally empty.
}

TEST_F(HostFunctionWrapperTest, Unwrap) {
  // Guest address 0 does not unwrap to any host function.
  EXPECT_EQ(UnwrapHostFunction(0), nullptr);

  // A host function that was never registered does not unwrap either.
  EXPECT_EQ(UnwrapHostFunction(ToGuestAddr(do_not_wrap_me)), nullptr);

  // After registration with a trampoline, the guest address unwraps back to the host function.
  const HostCode wrap_me_code = const_cast<HostCode>(reinterpret_cast<void*>(wrap_me));
  WrapHostFunctionImpl(wrap_me_code, trampoline, "wrap_me");
  EXPECT_EQ(UnwrapHostFunction(ToGuestAddr(wrap_me)), wrap_me);
}

// Host callee for the integer tests: returns x minus y.
int sub(int x, int y) {
  const int difference = x - y;
  return difference;
}

TEST_F(HostFunctionWrapperTest, WrapTwoInt) {
  // Guest code equivalent to:
  //
  // int caller(int (*ptr)(int, int)) {
  //   return ptr(5, 7);
  // }
  //
  // The callee address arrives in a0, which is also the first argument register. It is therefore
  // copied to t1 before the arguments are loaded, and the call is made as a tail call through t1.
  GuestAddr pc = MakeGuestExecRegion<uint32_t>({
      0x00050313,  // mv t1, a0
      0x00500513,  // li a0, #5
      0x00700593,  // li a1, #7
      0x00030067   // jr t1
  });

  // sub(5, 7) == -2.
  EXPECT_EQ(CallWrappedHostFunctionFromWrappedGuestFunction<int>(pc, sub), -2);
}

// Host callee for the single-precision test: returns x minus y.
float fsub(float x, float y) {
  const float difference = x - y;
  return difference;
}

TEST_F(HostFunctionWrapperTest, WrapTwoFloat) {
  // Guest code equivalent to:
  //
  // float caller(float (*ptr)(float, float)) {
  //   return ptr(5.0f, 7.0f);
  // }
  //
  // The float arguments are passed in fa0/fa1, hence the kLp64d calling convention below. The
  // callee address stays in a0 and is tail-called.
  const GuestAddr guest_code = MakeGuestExecRegion<uint32_t>({
      0x00500593,  // li a1, #5
      0xd025f553,  // fcvt.s.l fa0, a1
      0x00700593,  // li a1, #7
      0xd025f5d3,  // fcvt.s.l fa1, a1
      0x00050067,  // jr a0
  });

  const float result =
      CallWrappedHostFunctionFromWrappedGuestFunction<float, GuestAbi::kLp64d>(guest_code, fsub);
  EXPECT_FLOAT_EQ(result, -2.0f);
}

// Host callee for the double-precision test: returns x minus y.
double dsub(double x, double y) {
  const double difference = x - y;
  return difference;
}

TEST_F(HostFunctionWrapperTest, WrapTwoDouble) {
  // Guest code equivalent to:
  //
  // double caller(double (*ptr)(double, double)) {
  //   return ptr(5.0, 7.0);
  // }
  //
  // The double arguments are passed in fa0/fa1, hence the kLp64d calling convention below. The
  // callee address stays in a0 and is tail-called.
  const GuestAddr guest_code = MakeGuestExecRegion<uint32_t>({
      0x00500593,  // li a1, #5
      0xd225f553,  // fcvt.d.l fa0, a1
      0x00700593,  // li a1, #7
      0xd225f5d3,  // fcvt.d.l fa1, a1
      0x00050067,  // jr a0
  });

  const double result =
      CallWrappedHostFunctionFromWrappedGuestFunction<double, GuestAbi::kLp64d>(guest_code, dsub);
  EXPECT_DOUBLE_EQ(result, -2.0);
}

// Host callee returned by add_sub_chooser(1): returns x plus y.
int add(int x, int y) {
  const int sum = x + y;
  return sum;
}

// Host function that returns another host function: 1 selects add, 2 selects sub,
// any other value yields nullptr.
auto add_sub_chooser(int n) -> int (*)(int, int) {
  switch (n) {
    case 1:
      return add;
    case 2:
      return sub;
    default:
      return nullptr;
  }
}

TEST_F(HostFunctionWrapperTest, WrapReturnedFunction) {
  // Guest code equivalent to:
  //
  // int caller(int (*(*ptr)(int))(int, int), int n) {
  //   return ptr(n)(5, 7);
  // }
  //
  // ptr(n) is called with a frame that saves ra. The function pointer it returns (in a0) is then
  // moved to t1, the arguments are loaded, the frame is torn down, and t1 is tail-called. This
  // checks that a host function pointer returned from a wrapped host function is usable by guest
  // code.
  GuestAddr pc = MakeGuestExecRegion<uint32_t>({
      0xff010113,  // addi sp, sp, #-16
      0x00113423,  // sd ra, 8(sp)
      0x00050613,  // mv a2, a0
      0x00058513,  // mv a0, a1
      0x000600e7,  // jalr a2
      0x00050313,  // mv t1, a0
      0x00500513,  // li a0, #5
      0x00700593,  // li a1, #7
      0x00813083,  // ld ra, 8(sp)
      0x01010113,  // addi sp, sp, #16
      0x00030067,  // jr t1
  });

  // n == 1 selects add: add(5, 7) == 12.
  EXPECT_EQ(CallWrappedHostFunctionFromWrappedGuestFunction<int>(pc, add_sub_chooser, 1), 12);
  // n == 2 selects sub: sub(5, 7) == -2.
  EXPECT_EQ(CallWrappedHostFunctionFromWrappedGuestFunction<int>(pc, add_sub_chooser, 2), -2);
  // n == 0 makes the chooser return nullptr; the whole call is expected to terminate the process.
  EXPECT_DEATH(CallWrappedHostFunctionFromWrappedGuestFunction<int>(pc, add_sub_chooser, 0), "");
}

}  // namespace

}  // namespace berberis
