// Copyright 2023 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

#include "pw_bluetooth_sapphire/internal/host/sm/pairing_channel.h"

#include <memory>

#include "lib/fit/function.h"
#include "pw_bluetooth_sapphire/internal/host/common/byte_buffer.h"
#include "pw_bluetooth_sapphire/internal/host/hci-spec/protocol.h"
#include "pw_bluetooth_sapphire/internal/host/hci/connection.h"
#include "pw_bluetooth_sapphire/internal/host/l2cap/mock_channel_test.h"
#include "pw_bluetooth_sapphire/internal/host/sm/smp.h"
#include "pw_bluetooth_sapphire/internal/host/sm/util.h"
#include "pw_bluetooth_sapphire/internal/host/testing/test_helpers.h"

namespace bt::sm {
namespace {

class FakeChannelHandler : public PairingChannel::Handler {
 public:
  FakeChannelHandler() : weak_self_(this) {}

  void OnRxBFrame(ByteBufferPtr data) override {
    last_rx_data_ = std::move(data);
    frames_received_++;
  }

  void OnChannelClosed() override { channel_closed_count_++; }

  ByteBuffer* last_rx_data() { return last_rx_data_.get(); }

  int frames_received() const { return frames_received_; }
  int channel_closed_count() const { return channel_closed_count_; }
  PairingChannel::Handler::WeakPtr as_weak_handler() {
    return weak_self_.GetWeakPtr();
  }

 private:
  ByteBufferPtr last_rx_data_ = nullptr;
  int frames_received_ = 0;
  int channel_closed_count_ = 0;
  WeakSelf<PairingChannel::Handler> weak_self_;
};

// Fixture that layers a PairingChannel on top of a fake L2CAP channel obtained
// from MockChannelTest::CreateFakeChannel(). A default LE pairing channel is
// created in SetUp(); tests may call NewPairingChannel() again to replace it
// with one using a different link type or MTU.
class PairingChannelTest : public l2cap::testing::MockChannelTest {
 protected:
  void SetUp() override { NewPairingChannel(); }

  void TearDown() override {
    // Drop the PairingChannel first, then chain to the base fixture so that
    // whatever teardown work MockChannelTest performs is not skipped by this
    // override.
    sm_chan_ = nullptr;
    l2cap::testing::MockChannelTest::TearDown();
  }

  // Creates a fake channel and wraps it in a fresh PairingChannel, replacing
  // any PairingChannel previously held by the fixture.
  //
  // ll_type: link type of the fake channel. LE selects the fixed channel ID
  //          l2cap::kLESMPChannelId; any other value selects
  //          l2cap::kSMPChannelId.
  // mtu:     MTU value handed to ChannelOptions for the fake channel.
  void NewPairingChannel(bt::LinkType ll_type = bt::LinkType::kLE,
                         uint16_t mtu = kNoSecureConnectionsMtu) {
    l2cap::ChannelId cid = ll_type == bt::LinkType::kLE ? l2cap::kLESMPChannelId
                                                        : l2cap::kSMPChannelId;
    ChannelOptions options(cid, mtu);
    options.link_type = ll_type;
    fake_sm_chan_ = CreateFakeChannel(options);
    // The PairingChannel is given a callback that forwards to ResetTimer(), so
    // tests can observe timer resets through set_timer_resetter().
    sm_chan_ = std::make_unique<PairingChannel>(
        fake_sm_chan_->GetWeakPtr(),
        fit::bind_member<&PairingChannelTest::ResetTimer>(this));
  }

  // The PairingChannel under test; nullptr once TearDown() has run.
  PairingChannel* sm_chan() { return sm_chan_.get(); }

  // Installs the closure that runs whenever the PairingChannel invokes its
  // timer-reset callback. The default closure does nothing.
  void set_timer_resetter(fit::closure timer_resetter) {
    timer_resetter_ = std::move(timer_resetter);
  }

 private:
  // Bound into the PairingChannel as its timer-reset callback.
  void ResetTimer() { timer_resetter_(); }

  l2cap::testing::FakeChannel::WeakPtr fake_sm_chan_;
  std::unique_ptr<PairingChannel> sm_chan_;
  fit::closure timer_resetter_ = []() {};
};

// Death tests reuse the regular fixture under a separate suite name.
using PairingChannelDeathTest = PairingChannelTest;

// Creating a PairingChannel over a channel whose MTU is one below
// kNoSecureConnectionsMtu is expected to abort.
TEST_F(PairingChannelDeathTest, L2capChannelMtuTooSmallDies) {
  const uint16_t too_small_mtu = kNoSecureConnectionsMtu - 1;
  ASSERT_DEATH_IF_SUPPORTED(NewPairingChannel(bt::LinkType::kLE, too_small_mtu),
                            ".*max.*_sdu_size.*");
}

// SendMessage() is expected to abort when asked to send something malformed.
TEST_F(PairingChannelDeathTest, SendInvalidMessageDies) {
  // Case 1: the SMP code itself is invalid.
  const ErrorCode any_payload = ErrorCode::kUnspecifiedReason;
  EXPECT_DEATH_IF_SUPPORTED(sm_chan()->SendMessage(0xFF, any_payload),
                            ".*end.*");

  // Case 2: the SMP code is valid, but the payload type does not match it.
  const PairingRequestParams mismatched_payload{};
  EXPECT_DEATH_IF_SUPPORTED(
      sm_chan()->SendMessage(kPairingFailed, mismatched_payload),
      ".*sizeof.*");
}

// Sending a valid message puts the matching packet on the channel and invokes
// the timer-reset callback.
TEST_F(PairingChannelTest, SendMessageWorks) {
  // Hand-assemble the packet that SendMessage() is expected to emit.
  PairingRandomValue random_value = {1, 2, 3, 4, 5};
  StaticByteBuffer<util::PacketSize<PairingRandomValue>()> expected_packet;
  PacketWriter writer(kPairingRandom, &expected_packet);
  *writer.mutable_payload<PairingRandomValue>() = random_value;

  // Observe whether the PairingChannel asks for the timer to be reset.
  bool timer_was_reset = false;
  set_timer_resetter([&timer_was_reset]() { timer_was_reset = true; });

  EXPECT_PACKET_OUT(expected_packet);
  sm_chan()->SendMessage(kPairingRandom, random_value);
  RunUntilIdle();
  ASSERT_TRUE(timer_was_reset);
}

// With no handler installed, inbound data and a channel close must both be
// tolerated by PairingChannel without crashing. The test passes simply by
// running to completion.
TEST_F(PairingChannelTest, NoHandlerSetDataDropped) {
  ASSERT_TRUE(sm_chan());
  const StaticByteBuffer kInboundPacket(kPairingFailed,
                                        ErrorCode::kPairingNotSupported);

  // Deliver a frame while nobody is listening.
  fake_chan()->Receive(kInboundPacket);
  RunUntilIdle();

  // Close the channel while nobody is listening.
  fake_chan()->Close();
  RunUntilIdle();
}

// Once a handler is installed, each inbound frame and the channel-closed event
// are forwarded to it.
TEST_F(PairingChannelTest, SetHandlerReceivesData) {
  ASSERT_TRUE(sm_chan());
  const StaticByteBuffer kFirstPacket(kPairingFailed,
                                      ErrorCode::kPairingNotSupported);
  const StaticByteBuffer kSecondPacket(kPairingFailed,
                                       ErrorCode::kConfirmValueFailed);

  FakeChannelHandler rx_handler;
  sm_chan()->SetChannelHandler(rx_handler.as_weak_handler());
  // Nothing has been delivered yet.
  ASSERT_EQ(rx_handler.last_rx_data(), nullptr);
  ASSERT_EQ(rx_handler.frames_received(), 0);

  // First frame reaches the handler.
  fake_chan()->Receive(kFirstPacket);
  RunUntilIdle();
  ByteBuffer* const first_rx = rx_handler.last_rx_data();
  ASSERT_NE(first_rx, nullptr);
  EXPECT_TRUE(ContainersEqual(*first_rx, kFirstPacket));
  ASSERT_EQ(rx_handler.frames_received(), 1);

  // Second frame replaces the first as the most recent one.
  fake_chan()->Receive(kSecondPacket);
  RunUntilIdle();
  ByteBuffer* const second_rx = rx_handler.last_rx_data();
  ASSERT_NE(second_rx, nullptr);
  EXPECT_TRUE(ContainersEqual(*second_rx, kSecondPacket));
  ASSERT_EQ(rx_handler.frames_received(), 2);

  // Closing the channel notifies the handler exactly once.
  fake_chan()->Close();
  RunUntilIdle();
  ASSERT_EQ(rx_handler.channel_closed_count(), 1);
}

// After the handler is swapped, only the replacement handler receives later
// frames and the channel-closed event; the original handler is left untouched.
TEST_F(PairingChannelTest, ChangeHandlerNewHandlerReceivesData) {
  ASSERT_TRUE(sm_chan());
  const StaticByteBuffer kFirstPacket(kPairingFailed,
                                      ErrorCode::kPairingNotSupported);
  const StaticByteBuffer kSecondPacket(kPairingFailed,
                                       ErrorCode::kConfirmValueFailed);

  // Install the original handler and confirm it starts out empty.
  FakeChannelHandler old_handler;
  sm_chan()->SetChannelHandler(old_handler.as_weak_handler());
  ASSERT_EQ(old_handler.last_rx_data(), nullptr);
  ASSERT_EQ(old_handler.frames_received(), 0);

  // The first frame goes to the original handler.
  fake_chan()->Receive(kFirstPacket);
  RunUntilIdle();
  ASSERT_NE(old_handler.last_rx_data(), nullptr);
  EXPECT_TRUE(ContainersEqual(*old_handler.last_rx_data(), kFirstPacket));
  ASSERT_EQ(old_handler.frames_received(), 1);

  // Swap in the replacement handler; the second frame must go to it.
  FakeChannelHandler replacement;
  ASSERT_EQ(replacement.last_rx_data(), nullptr);
  sm_chan()->SetChannelHandler(replacement.as_weak_handler());
  fake_chan()->Receive(kSecondPacket);
  RunUntilIdle();
  ASSERT_NE(replacement.last_rx_data(), nullptr);
  EXPECT_TRUE(ContainersEqual(*replacement.last_rx_data(), kSecondPacket));
  ASSERT_EQ(replacement.frames_received(), 1);
  // The original handler still holds only what it saw before the swap.
  EXPECT_TRUE(ContainersEqual(*old_handler.last_rx_data(), kFirstPacket));
  ASSERT_EQ(old_handler.frames_received(), 1);

  // Only the replacement handler hears about the channel closing.
  fake_chan()->Close();
  RunUntilIdle();
  ASSERT_EQ(replacement.channel_closed_count(), 1);
  ASSERT_EQ(old_handler.channel_closed_count(), 0);
}
}  // namespace
}  // namespace bt::sm
