// Copyright 2020 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

#include <optional>

#include "pw_rpc/internal/test_utils.h"
#include "pw_rpc/nanopb/client_reader_writer.h"
#include "pw_rpc_nanopb_private/internal_test_utils.h"
#include "pw_rpc_test_protos/test.pb.h"
#include "pw_unit_test/framework.h"

PW_MODIFY_DIAGNOSTICS_PUSH();
PW_MODIFY_DIAGNOSTIC(ignored, "-Wmissing-field-initializers");

namespace pw::rpc {
namespace {

using internal::ClientContextForTest;

// Service ID passed to every call started by FakeGeneratedServiceClient.
constexpr uint32_t kServiceId = 16;
// Method ID passed when starting the unary calls.
constexpr uint32_t kUnaryMethodId = 111;
// Method ID passed when starting the server streaming call.
constexpr uint32_t kServerStreamingMethodId = 112;

// Test helper exposing static functions that start Nanopb client calls on
// service kServiceId. Each function forwards its arguments to the matching
// internal Start() function and returns the resulting call object.
class FakeGeneratedServiceClient {
 public:
  // Starts a unary call using kUnaryMethodId. The callbacks are moved into the
  // call; on_error may be omitted.
  static NanopbUnaryReceiver<pw_rpc_test_TestResponse> TestUnaryRpc(
      Client& client,
      uint32_t channel_id,
      const pw_rpc_test_TestRequest& request,
      Function<void(const pw_rpc_test_TestResponse&, Status)> on_response,
      Function<void(Status)> on_error = nullptr) {
    using Receiver = NanopbUnaryReceiver<pw_rpc_test_TestResponse>;
    using UnaryCall =
        internal::NanopbUnaryResponseClientCall<pw_rpc_test_TestResponse>;

    return UnaryCall::Start<Receiver>(
        client,
        channel_id,
        kServiceId,
        kUnaryMethodId,
        internal::kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
                                     pw_rpc_test_TestResponse_fields>,
        std::move(on_response),
        std::move(on_error),
        request);
  }

  // Starts a second unary call with the same signature as TestUnaryRpc.
  // NOTE(review): this also passes kUnaryMethodId, so it targets the same
  // method as TestUnaryRpc -- confirm that sharing the ID is intended.
  static NanopbUnaryReceiver<pw_rpc_test_TestResponse> TestAnotherUnaryRpc(
      Client& client,
      uint32_t channel_id,
      const pw_rpc_test_TestRequest& request,
      Function<void(const pw_rpc_test_TestResponse&, Status)> on_response,
      Function<void(Status)> on_error = nullptr) {
    using Receiver = NanopbUnaryReceiver<pw_rpc_test_TestResponse>;
    using UnaryCall =
        internal::NanopbUnaryResponseClientCall<pw_rpc_test_TestResponse>;

    return UnaryCall::Start<Receiver>(
        client,
        channel_id,
        kServiceId,
        kUnaryMethodId,
        internal::kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
                                     pw_rpc_test_TestResponse_fields>,
        std::move(on_response),
        std::move(on_error),
        request);
  }

  // Starts a server streaming call using kServerStreamingMethodId.
  // on_response, on_stream_end and on_error are all moved into the call;
  // on_error may be omitted.
  static NanopbClientReader<pw_rpc_test_TestStreamResponse> TestServerStreamRpc(
      Client& client,
      uint32_t channel_id,
      const pw_rpc_test_TestRequest& request,
      Function<void(const pw_rpc_test_TestStreamResponse&)> on_response,
      Function<void(Status)> on_stream_end,
      Function<void(Status)> on_error = nullptr) {
    using Reader = NanopbClientReader<pw_rpc_test_TestStreamResponse>;
    using StreamCall = internal::NanopbStreamResponseClientCall<
        pw_rpc_test_TestStreamResponse>;

    return StreamCall::Start<Reader>(
        client,
        channel_id,
        kServiceId,
        kServerStreamingMethodId,
        internal::kNanopbMethodSerde<pw_rpc_test_TestRequest_fields,
                                     pw_rpc_test_TestStreamResponse_fields>,
        std::move(on_response),
        std::move(on_stream_end),
        std::move(on_error),
        request);
  }
};

// Verifies that starting a unary call sends exactly one packet, addressed to
// the expected channel, service and method, whose payload decodes to the
// request that was passed in.
TEST(NanopbClientCall, Unary_SendsRequestPacket) {
  ClientContextForTest context;

  // The call object is kept in scope for the duration of the test.
  auto call = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      {.integer = 123, .status_code = 0},
      nullptr);

  // Fatal assertion: the remaining checks inspect the last packet, which is
  // only meaningful when exactly one packet was sent.
  ASSERT_EQ(context.output().total_packets(), 1u);
  auto packet = context.output().last_packet();
  EXPECT_EQ(packet.channel_id(), context.channel().id());
  EXPECT_EQ(packet.service_id(), kServiceId);
  EXPECT_EQ(packet.method_id(), kUnaryMethodId);

  // Decode the payload and check that the request field round-tripped.
  PW_DECODE_PB(pw_rpc_test_TestRequest, sent_proto, packet.payload());
  EXPECT_EQ(sent_proto.integer, 123);
}

// Fixture holding the state that the unary-call tests' callbacks write to, so
// each test can check afterwards what was delivered.
class UnaryClientCall : public ::testing::Test {
 protected:
  // Number of times a test's response callback has run.
  int responses_received_{0};
  // `value` field of the most recent response seen by the response callback.
  int last_response_value_{0};
  // Status passed with the most recent response; empty until one arrives.
  std::optional<Status> last_status_{};
  // Status passed to the error callback; empty until an error is reported.
  std::optional<Status> last_error_{};
};

// A well-formed response must reach the response callback exactly once, with
// the status that was sent and the decoded payload value.
TEST_F(UnaryClientCall, InvokesCallbackOnValidResponse) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto receiver = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& reply, Status reply_status) {
        last_response_value_ = reply.value;
        last_status_ = reply_status;
        ++responses_received_;
      });

  PW_ENCODE_PB(pw_rpc_test_TestResponse, encoded_reply, .value = 42);
  EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), encoded_reply));

  // Stop if the callback did not run exactly once; the recorded values would
  // not be meaningful otherwise.
  ASSERT_EQ(responses_received_, 1);
  EXPECT_EQ(last_status_, OkStatus());
  EXPECT_EQ(last_response_value_, 42);
}

// Delivering a response to a call that was started with a null response
// callback must be handled without any callback being recorded.
TEST_F(UnaryClientCall, DoesNothingOnNullCallback) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto receiver = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(), context.channel().id(), request, nullptr);

  PW_ENCODE_PB(pw_rpc_test_TestResponse, encoded_reply, .value = 42);
  EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), encoded_reply));

  ASSERT_EQ(responses_received_, 0);
}

// A response whose payload cannot be decoded must be reported through the
// error callback as DATA_LOSS instead of reaching the response callback.
TEST_F(UnaryClientCall, InvokesErrorCallbackOnInvalidResponse) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto receiver = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& reply, Status reply_status) {
        last_response_value_ = reply.value;
        last_status_ = reply_status;
        ++responses_received_;
      },
      [this](Status error) { last_error_ = error; });

  // Three bytes that the test expects not to decode as a TestResponse.
  constexpr std::byte kUndecodable[]{
      std::byte{0xab},
      std::byte{0xcd},
      std::byte{0xef},
  };
  EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), kUndecodable));

  EXPECT_EQ(responses_received_, 0);
  ASSERT_TRUE(last_error_.has_value());
  EXPECT_EQ(last_error_, Status::DataLoss());
}

// A SERVER_ERROR packet must be routed to the error callback with the status
// it carries, and must not trigger the response callback.
TEST_F(UnaryClientCall, InvokesErrorCallbackOnServerError) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto receiver = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& reply, Status reply_status) {
        last_response_value_ = reply.value;
        last_status_ = reply_status;
        ++responses_received_;
      },
      [this](Status error) { last_error_ = error; });

  const Status send_result = context.SendPacket(
      internal::pwpb::PacketType::SERVER_ERROR, Status::NotFound());
  EXPECT_EQ(OkStatus(), send_result);

  EXPECT_EQ(responses_received_, 0);
  EXPECT_EQ(last_error_, Status::NotFound());
}

// With no error callback registered, an undecodable response must be handled
// without invoking the response callback.
TEST_F(UnaryClientCall, DoesNothingOnErrorWithoutCallback) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto receiver = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& reply, Status reply_status) {
        last_response_value_ = reply.value;
        last_status_ = reply_status;
        ++responses_received_;
      });

  // Three bytes that the test expects not to decode as a TestResponse.
  constexpr std::byte kUndecodable[]{
      std::byte{0xab},
      std::byte{0xcd},
      std::byte{0xef},
  };
  EXPECT_EQ(OkStatus(), context.SendResponse(OkStatus(), kUndecodable));

  EXPECT_EQ(responses_received_, 0);
}

// Only the first of several responses may reach the callback; the later ones
// must leave the recorded count, status and value untouched.
TEST_F(UnaryClientCall, OnlyReceivesOneResponse) {
  ClientContextForTest context;
  const pw_rpc_test_TestRequest request{.integer = 123, .status_code = 0};

  auto receiver = FakeGeneratedServiceClient::TestUnaryRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestResponse& reply, Status reply_status) {
        last_response_value_ = reply.value;
        last_status_ = reply_status;
        ++responses_received_;
      });

  // Encode the three distinct payloads up front.
  PW_ENCODE_PB(pw_rpc_test_TestResponse, first, .value = 42);
  PW_ENCODE_PB(pw_rpc_test_TestResponse, second, .value = 44);
  PW_ENCODE_PB(pw_rpc_test_TestResponse, third, .value = 46);

  // Deliver them in order, each with a different status so that the one that
  // reached the callback can be identified.
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::Unimplemented(), first));
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::OutOfRange(), second));
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::Internal(), third));

  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_status_, Status::Unimplemented());
  EXPECT_EQ(last_response_value_, 42);
}

// Fixture holding the state that the server-streaming tests' callbacks write
// to, so each test can check afterwards what was delivered.
class ServerStreamingClientCall : public ::testing::Test {
 protected:
  // Number of stream responses handed to a test's response callback.
  int responses_received_{0};
  // `number` field of the most recent stream response seen by the callback.
  int last_response_number_{0};
  // Status passed to the stream-end callback; empty until it runs.
  std::optional<Status> stream_status_{};
  // Status passed to the error callback; empty until an error is reported.
  std::optional<Status> rpc_error_{};
};

// Verifies that starting a server streaming call sends exactly one packet,
// addressed to the expected channel, service and method, whose payload decodes
// to the request that was passed in.
TEST_F(ServerStreamingClientCall, SendsRequestPacket) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

  // The call object is kept in scope for the duration of the test.
  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      {.integer = 71, .status_code = 0},
      nullptr,
      nullptr);

  // Fatal assertion: the remaining checks inspect the last packet, which is
  // only meaningful when exactly one packet was sent.
  ASSERT_EQ(context.output().total_packets(), 1u);
  auto packet = context.output().last_packet();
  EXPECT_EQ(packet.channel_id(), context.channel().id());
  EXPECT_EQ(packet.service_id(), kServiceId);
  EXPECT_EQ(packet.method_id(), kServerStreamingMethodId);

  // Decode the payload and check that the request field round-tripped.
  PW_DECODE_PB(pw_rpc_test_TestRequest, sent_proto, packet.payload());
  EXPECT_EQ(sent_proto.integer, 71);
}

// Every stream packet with a valid payload must be handed to the response
// callback, and the call must stay active in between.
TEST_F(ServerStreamingClientCall, InvokesCallbackOnValidResponse) {
  using StreamingContext =
      ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId>;
  StreamingContext context;
  const pw_rpc_test_TestRequest request{.integer = 71, .status_code = 0};

  auto reader = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestStreamResponse& message) {
        last_response_number_ = message.number;
        ++responses_received_;
      },
      [this](Status end_status) { stream_status_ = end_status; });

  // First stream message.
  PW_ENCODE_PB(
      pw_rpc_test_TestStreamResponse, first, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(first));
  EXPECT_TRUE(reader.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_response_number_, 11);

  // Second stream message.
  PW_ENCODE_PB(
      pw_rpc_test_TestStreamResponse, second, .chunk = {}, .number = 22u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(second));
  EXPECT_TRUE(reader.active());
  EXPECT_EQ(responses_received_, 2);
  EXPECT_EQ(last_response_number_, 22);

  // Third stream message.
  PW_ENCODE_PB(
      pw_rpc_test_TestStreamResponse, third, .chunk = {}, .number = 33u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(third));
  EXPECT_TRUE(reader.active());
  EXPECT_EQ(responses_received_, 3);
  EXPECT_EQ(last_response_number_, 33);
}

// Verifies how a server streaming call ends: once the final response is sent
// with SendResponse(), the stream-end callback receives that status, the call
// is no longer active, and later stream packets do not reach the response
// callback.
TEST_F(ServerStreamingClientCall, InvokesStreamEndOnFinish) {
  ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId> context;

  auto call = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      {.integer = 71, .status_code = 0},
      [this](const pw_rpc_test_TestStreamResponse& response) {
        ++responses_received_;
        last_response_number_ = response.number;
      },
      [this](Status status) { stream_status_ = status; });

  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r1, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r1));
  EXPECT_TRUE(call.active());

  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r2, .chunk = {}, .number = 22u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r2));
  EXPECT_TRUE(call.active());

  // The stream-end callback must not have run while the stream is still open.
  EXPECT_FALSE(stream_status_.has_value());

  // Close the stream.
  EXPECT_EQ(OkStatus(), context.SendResponse(Status::NotFound()));

  // The stream-end callback must have received the closing status.
  EXPECT_EQ(stream_status_, Status::NotFound());

  // A stream packet sent after the close must not be delivered.
  PW_ENCODE_PB(pw_rpc_test_TestStreamResponse, r3, .chunk = {}, .number = 33u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(r3));
  EXPECT_FALSE(call.active());

  EXPECT_EQ(responses_received_, 2);
}

// A stream packet whose payload cannot be decoded must end the call and
// report DATA_LOSS through the error callback, without delivering a response.
TEST_F(ServerStreamingClientCall, ParseErrorTerminatesCallWithDataLoss) {
  using StreamingContext =
      ClientContextForTest<128, 99, kServiceId, kServerStreamingMethodId>;
  StreamingContext context;
  const pw_rpc_test_TestRequest request{.integer = 71, .status_code = 0};

  auto reader = FakeGeneratedServiceClient::TestServerStreamRpc(
      context.client(),
      context.channel().id(),
      request,
      [this](const pw_rpc_test_TestStreamResponse& message) {
        last_response_number_ = message.number;
        ++responses_received_;
      },
      nullptr,
      [this](Status failure) { rpc_error_ = failure; });

  // A valid message is delivered normally and leaves the call active.
  PW_ENCODE_PB(
      pw_rpc_test_TestStreamResponse, valid, .chunk = {}, .number = 11u);
  EXPECT_EQ(OkStatus(), context.SendServerStream(valid));
  EXPECT_TRUE(reader.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(last_response_number_, 11);

  // Three bytes that the test expects not to decode as a TestStreamResponse.
  constexpr std::byte kUndecodable[]{
      std::byte{0xab},
      std::byte{0xcd},
      std::byte{0xef},
  };
  EXPECT_EQ(OkStatus(), context.SendServerStream(kUndecodable));
  EXPECT_FALSE(reader.active());
  EXPECT_EQ(responses_received_, 1);
  EXPECT_EQ(rpc_error_, Status::DataLoss());
}

}  // namespace
}  // namespace pw::rpc

PW_MODIFY_DIAGNOSTICS_POP();
