/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef _DNS_DNSTLSSOCKET_H
#define _DNS_DNSTLSSOCKET_H

#include <openssl/ssl.h>
#include <chrono>
#include <cstdint>
#include <future>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

#include <android-base/thread_annotations.h>
#include <android-base/unique_fd.h>
#include <netdutils/Slice.h>
#include <netdutils/Status.h>

#include "DnsTlsServer.h"
#include "IDnsTlsSocket.h"
#include "LockedQueue.h"

namespace android {
namespace net {

class IDnsTlsSocketObserver;
class DnsTlsSessionCache;

// A class for managing a TLS socket that sends and receives messages in
// [length][value] format, with a 2-byte length (i.e. DNS-over-TCP format).
// This class is not aware of query-response pairing or anything else about DNS.
// For the observer:
// This class is not re-entrant: the observer is not permitted to wait for a call to query()
// or the destructor in a callback.  Doing so will result in deadlocks.
// This class may call the observer at any time after initialize(), until the destructor
// returns (but not after).
//
// Calls to IDnsTlsSocketObserver in a DnsTlsSocket life cycle:
//
//                                UNINITIALIZED
//                                      |
//                                      v
//                                 INITIALIZED
//                                      |
//                                      v
//                            +----CONNECTING------+
//            Handshake fails |                    | Handshake succeeds
//   (onClose() when          |                    |
//    mAsyncHandshake is set) |                    v
//                            |        +---> CONNECTED --+
//                            |        |           |     |
//                            |        +-----------+     | Idle timeout
//                            |   Send/Recv queries      | onClose()
//                            |   onResponse()           |
//                            |                          |
//                            |                          |
//                            +--> WAIT_FOR_DELETE <-----+
//
//
// TODO: Add onHandshakeFinished() for handshake results.
class DnsTlsSocket : public IDnsTlsSocket {
  public:
    enum class State {
        UNINITIALIZED,
        INITIALIZED,
        CONNECTING,
        CONNECTED,
        WAIT_FOR_DELETE,
    };

    // Neither |observer| nor |cache| is owned: both are stored as raw pointers, so the caller
    // must keep them alive while this socket may use them (for the observer, until the
    // destructor returns).
    DnsTlsSocket(const DnsTlsServer& server, unsigned mark,
                 IDnsTlsSocketObserver* _Nonnull observer, DnsTlsSessionCache* _Nonnull cache)
        : mMark(mark), mServer(server), mObserver(observer), mCache(cache) {}
    ~DnsTlsSocket();

    // This object owns a mutex, a loop thread and file descriptors; it is neither copyable nor
    // movable (deleting the copy operations also suppresses the implicit move operations).
    DnsTlsSocket(const DnsTlsSocket&) = delete;
    DnsTlsSocket& operator=(const DnsTlsSocket&) = delete;

    // Creates the SSL context for this session. Returns false on failure.
    // This method should be called after construction and before use of a DnsTlsSocket.
    // Only call this method once per DnsTlsSocket.
    bool initialize() EXCLUDES(mLock);

    // If async handshake is enabled, this function simply signals a handshake request, and the
    // handshake will be performed in the loop thread; otherwise, if async handshake is disabled,
    // this function performs the handshake and returns after the handshake finishes.
    bool startHandshake() EXCLUDES(mLock);

    // Sends a query on the SSL socket.  |query| contains the body of a query, not including
    // the ID header. This function will typically return before the query is actually sent.
    // If this function fails, the IDnsTlsSocketObserver will be notified that the socket is
    // closed.
    // Note that success here indicates successful sending, not receipt of a response.
    // Thread-safe.
    bool query(uint16_t id, const netdutils::Slice query) override EXCLUDES(mLock);

  private:
    // Lock to be held by the SSL event loop thread.  This is not normally in contention.
    std::mutex mLock;

    // Forwards queries and receives responses.  Blocks until the idle timeout.
    void loop() EXCLUDES(mLock);
    std::unique_ptr<std::thread> mLoopThread GUARDED_BY(mLock);

    // On success, sets mSslFd to a TCP socket connected (or still connecting) to mServer.
    // On error, returns a non-ok Status describing the failure (the errno of the failed call).
    netdutils::Status tcpConnect() REQUIRES(mLock);

    bssl::UniquePtr<SSL> prepareForSslConnect(int fd) REQUIRES(mLock);

    // Connect an SSL session on the provided socket.  If connection fails, closing the
    // socket remains the caller's responsibility.
    bssl::UniquePtr<SSL> sslConnect(int fd) REQUIRES(mLock);

    // Connect an SSL session on the provided socket. This is an interruptible version
    // which allows the connection handshake to be terminated at any time.
    bssl::UniquePtr<SSL> sslConnectV2(int fd) REQUIRES(mLock);

    // Disconnect the SSL session and close the socket.
    void sslDisconnect() REQUIRES(mLock);

    // Writes a buffer to the socket.
    bool sslWrite(const netdutils::Slice buffer) REQUIRES(mLock);

    // Reads exactly the specified number of bytes from the socket, or fails.
    // Returns SSL_ERROR_NONE on success.
    // If |wait| is true, then this function always blocks.  Otherwise, it
    // will return SSL_ERROR_WANT_READ if there is no data from the server to read.
    int sslRead(const netdutils::Slice buffer, bool wait) REQUIRES(mLock);

    bool sendQuery(const std::vector<uint8_t>& buf) REQUIRES(mLock);

    // Reads one DNS response. It can potentially block until the exact number of bytes of
    // the response has been read.
    bool readResponse() REQUIRES(mLock);

    // It is only used for DNS-OVER-TLS internal test.
    bool setTestCaCertificate() REQUIRES(mLock);

    // Similar to query(), this function uses incrementEventFd to send a message to the
    // loop thread.  However, instead of incrementing the counter by one (indicating a
    // new query), it wraps the counter to negative, which we use to indicate a shutdown
    // request.
    void requestLoopShutdown() EXCLUDES(mLock);

    // This function sends a message to the loop thread by incrementing mEventFd.
    bool incrementEventFd(int64_t count) EXCLUDES(mLock);

    // Transition the state from expected state |from| to new state |to|.
    void transitionState(State from, State to) REQUIRES(mLock);

    // Queue of pending queries.  query() pushes items onto the queue and notifies
    // the loop thread by incrementing mEventFd.  loop() reads items off the queue.
    LockedQueue<std::vector<uint8_t>> mQueue;

    // eventfd used for notifying the SSL thread when queries are ready to send.
    // This fd acts similarly to an atomic counter, incremented by query() and cleared
    // by loop().  We have to use a file descriptor because the SSL thread needs to wait in
    // poll() for input from either a remote server or a query thread.  Since eventfd does not
    // have EOF, we indicate a close request by setting the counter to a negative number.
    // This file descriptor is opened by initialize(), and closed implicitly after
    // destruction.
    // Note that: data starts being read from the eventfd when the state is CONNECTED.
    base::unique_fd mEventFd;

    // An eventfd used to listen to shutdown requests when the state is CONNECTING.
    // TODO: let |mEventFd| exclusively handle query requests, and let |mShutdownEvent| exclusively
    // handle shutdown requests.
    base::unique_fd mShutdownEvent;

    // SSL Socket fields.
    bssl::UniquePtr<SSL_CTX> mSslCtx GUARDED_BY(mLock);
    base::unique_fd mSslFd GUARDED_BY(mLock);
    bssl::UniquePtr<SSL> mSsl GUARDED_BY(mLock);
    static constexpr std::chrono::seconds kIdleTimeout = std::chrono::seconds(20);

    const unsigned mMark;  // Socket mark
    const DnsTlsServer mServer;
    IDnsTlsSocketObserver* _Nonnull const mObserver;
    DnsTlsSessionCache* _Nonnull const mCache;
    State mState GUARDED_BY(mLock) = State::UNINITIALIZED;

    // If true, defer the handshake to the loop thread; otherwise, run the handshake on caller's
    // thread (the call to startHandshake()).
    bool mAsyncHandshake GUARDED_BY(mLock) = false;

    // The time to wait for the attempt on connecting to the server.
    // Set the default value 127 seconds to be consistent with TCP connect timeout.
    // (presume net.ipv4.tcp_syn_retries = 6)
    static constexpr int kDotConnectTimeoutMs = 127 * 1000;
    // Connect timeout actually used. Initialized to the default so it is never read
    // uninitialized; presumably overwritten during initialize() — confirm in the .cpp.
    int mConnectTimeoutMs = kDotConnectTimeoutMs;

    // For testing.
    friend class DnsTlsSocketTest;
};

}  // end of namespace net
}  // end of namespace android

#endif  // _DNS_DNSTLSSOCKET_H
