/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "gtest/gtest.h"

#include "berberis/base/bit_util.h"
#include "berberis/guest_abi/function_wrappers.h"
#include "berberis/guest_abi/guest_abi.h"
#include "berberis/guest_abi/guest_type.h"
#include "berberis/guest_state/guest_addr.h"
#include "berberis/guest_state/guest_state_opaque.h"
#include "berberis/runtime_primitives/host_code.h"
#include "berberis/runtime_primitives/host_function_wrapper_impl.h"
#include "berberis/test_utils/guest_exec_region.h"
#include "berberis/test_utils/translation_test.h"

namespace berberis {

namespace {

class HostFunctionWrapperTest : public TranslationTest {
 protected:
  // Note: we are using both Guest Function wrapper and Host Function Wrapper here.
  //
  // We don't call wrapped function via GuestCall — this would ensure that they match but this could
  // easily happen while they both would violate AAPCS.
  //
  // Instead we use simple functions (usually with just an address of host-wrapped function as a
  // sole argument) and use compiler-generated code which verifies our wrappers support code
  // generated by compiler.
  //
  // This is also important because certain facilities are not supported by GuestCall currently
  // (e.g. large structures like with mallinfo(3) or AAPCS-VFP like certain Vulkan functions).
  template <typename GuestResultType,
            GuestAbi::CallingConventionsVariant kCallingConventionsVariant = GuestAbi::kDefaultAbi,
            typename Func,
            typename... AdditionalParams>
  GuestResultType CallWrappedHostFunctionFromWrappedGuestFunction(GuestAddr guest_function,
                                                                  Func host_function,
                                                                  AdditionalParams... params) {
    WrapHostFunction<kCallingConventionsVariant>(host_function, "HostFunction");

    auto caller = WrapGuestFunction(
        bit_cast<GuestType<GuestResultType (*)(Func, AdditionalParams...)>>(guest_function),
        "GuestFunction");

    return caller(host_function, params...);
  }
};

// Host function that is never wrapped; only its address is used by the Unwrap test.
void do_not_wrap_me() {
  // Intentionally empty.
}

// Host function that the Unwrap test wraps; only its address is used.
void wrap_me() {
  // Intentionally empty.
}

// No-op trampoline passed to WrapHostFunctionImpl by the Unwrap test; it ignores its arguments.
void trampoline(HostCode /* host_code */, ProcessState* /* state */) {
  // Intentionally empty.
}

TEST_F(HostFunctionWrapperTest, Unwrap) {
  InitBerberis();

  // Address 0 and a function that was never wrapped have nothing to unwrap.
  EXPECT_TRUE(UnwrapHostFunction(0) == nullptr);
  EXPECT_TRUE(UnwrapHostFunction(ToGuestAddr(do_not_wrap_me)) == nullptr);

  // After wrap_me is wrapped, unwrapping its guest address must give wrap_me back.
  auto* const wrap_me_code = reinterpret_cast<void*>(wrap_me);
  WrapHostFunctionImpl(const_cast<HostCode>(wrap_me_code), trampoline, "wrap_me");
  EXPECT_TRUE(UnwrapHostFunction(ToGuestAddr(wrap_me)) == wrap_me);
}

// Host callback for the wrapper tests: returns x minus y.
int sub(int x, int y) {
  const int difference = x - y;
  return difference;
}

TEST_F(HostFunctionWrapperTest, WrapTwoInt) {
  // Guest (arm64) code equivalent to:
  //
  //   int caller(int (*ptr)(int, int)) { return ptr(5, 7); }
  //
  // It loads w0 = 5 and w1 = 7, then branches (tail call) to the pointer received in x0.
  const GuestAddr caller_pc = MakeGuestExecRegion<uint32_t>({
      0xaa0003e2,  // mov x2, x0
      0x528000e1,  // mov w1, #0x7
      0x528000a0,  // mov w0, #0x5
      0xd61f0040,  // br x2
  });

  // sub(5, 7) is expected to come back through both wrappers unchanged.
  const int result = CallWrappedHostFunctionFromWrappedGuestFunction<int>(caller_pc, sub);
  EXPECT_EQ(-2, result);
}

// Host callback for the wrapper tests: returns x minus y in single precision.
float fsub(float x, float y) {
  const float difference = x - y;
  return difference;
}

TEST_F(HostFunctionWrapperTest, WrapTwoFloat) {
  // Guest (arm64) code equivalent to:
  //
  //   float caller(float (*ptr)(float, float)) {
  //     return ptr(5.0, 7.0);
  //   }
  //
  // The float arguments are placed in s0/s1 (the AAPCS64 floating-point argument registers),
  // then the code tail-calls the pointer received in x0.
  GuestAddr pc = MakeGuestExecRegion<uint32_t>({
      0x1e239001,  // fmov s1, #7.000000000000000000e+00
      0x1e229000,  // fmov s0, #5.000000000000000000e+00
      0xd61f0000,  // br   x0
  });

  // The result is a float. Compare it as a float, using gtest's ULP-based float equality,
  // rather than exact equality against a double literal.
  EXPECT_FLOAT_EQ(-2.0f, CallWrappedHostFunctionFromWrappedGuestFunction<float>(pc, fsub));
}

// Host callback for the wrapper tests: returns the sum of x and y.
int add(int x, int y) {
  const int total = x + y;
  return total;
}

// Host callback that itself returns a host function pointer:
// add for n == 1, sub for n == 2, and nullptr for any other n.
auto add_sub_chooser(int n) -> int (*)(int, int) {
  switch (n) {
    case 1:
      return add;
    case 2:
      return sub;
    default:
      return nullptr;
  }
}

TEST_F(HostFunctionWrapperTest, WrapReturnedFunction) {
  // Guest (arm64) code equivalent to:
  //
  //   int caller(int (*(*ptr)(int))(int, int), int n) {
  //     return ptr(n)(5, 7);
  //   }
  //
  // The guest first calls the host function ptr(n). It then tail-calls the function pointer
  // returned in x0 with (5, 7). This checks that a host function pointer returned to the
  // guest is itself callable from guest code.
  GuestAddr pc = MakeGuestExecRegion<uint32_t>({
      0xa9bf7bfd,  // stp x29, x30, [sp, #-16]!
      0xaa0003e2,  // mov x2, x0
      0x2a0103e0,  // mov w0, w1
      0x910003fd,  // mov x29, sp
      0xd63f0040,  // blr x2
      0xaa0003e2,  // mov x2, x0
      0xa8c17bfd,  // ldp x29, x30, [sp], #16
      0x528000e1,  // mov w1, #0x7
      0x528000a0,  // mov w0, #0x5
      0xd61f0040   // br  x2
  });

  // n == 1 selects add: add(5, 7) == 12.
  EXPECT_EQ(12, CallWrappedHostFunctionFromWrappedGuestFunction<int>(pc, add_sub_chooser, 1));
  // n == 2 selects sub: sub(5, 7) == -2.
  EXPECT_EQ(-2, CallWrappedHostFunctionFromWrappedGuestFunction<int>(pc, add_sub_chooser, 2));
  // Any other n makes add_sub_chooser return nullptr, so the guest code branches to address 0.
  // That is expected to terminate the process.
  EXPECT_DEATH(CallWrappedHostFunctionFromWrappedGuestFunction<int>(pc, add_sub_chooser, 0), "");
}

}  // namespace

}  // namespace berberis
