/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.googlecode.android_scripting;

import com.google.common.collect.Lists;

import java.io.IOException;
import java.net.BindException;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.net.UnknownHostException;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * A simple server.
 */
/**
 * A simple TCP server.
 *
 * <p>Once started, a background thread accepts incoming connections. Each accepted socket is
 * handed to {@link #handleConnection(Socket)} on its own {@code ConnectionThread}. Registered
 * {@link SimpleServerObserver}s are told when connections are established and torn down.
 */
public abstract class SimpleServer {
    /** Listen backlog passed to every {@link ServerSocket} this server creates. */
    private static final int BACKLOG = 5;

    private final CopyOnWriteArrayList<ConnectionThread> mConnectionThreads =
            new CopyOnWriteArrayList<>();
    // Observers are added/removed by callers and iterated from the accept thread and from every
    // ConnectionThread, so the list must tolerate concurrent modification during iteration.
    private final List<SimpleServerObserver> mObservers = new CopyOnWriteArrayList<>();
    private volatile boolean mStopServer = false;
    private ServerSocket mServer;
    private Thread mServerThread;

    /**
     * An interface for accessing SimpleServer events.
     */
    public interface SimpleServerObserver {
        /** The function to be called when a ConnectionThread is established.*/
        void onConnect();
        /** The function to be called when a ConnectionThread disconnects.*/
        void onDisconnect();
    }

    /** An abstract method for handling non-RPC connections. */
    protected abstract void handleConnection(Socket socket) throws Exception;

    /**
     * Adds an observer.
     */
    public void addObserver(SimpleServerObserver observer) {
        mObservers.add(observer);
    }

    /**
     * Removes an observer.
     */
    public void removeObserver(SimpleServerObserver observer) {
        mObservers.remove(observer);
    }

    /** Notifies all subscribers when a new ConnectionThread has been created.
     *
     * This applies to both newly instantiated sessions and continued sessions.
     */
    private void notifyOnConnect() {
        for (SimpleServerObserver observer : mObservers) {
            observer.onConnect();
        }
    }

    /** Notifies all subscribers when a ConnectionThread has been terminated. */
    private void notifyOnDisconnect() {
        for (SimpleServerObserver observer : mObservers) {
            observer.onDisconnect();
        }
    }

    /** An implementation of a thread that holds data about its server connection status. */
    private final class ConnectionThread extends Thread {
        /** The socket used for communication. */
        private final Socket mmSocket;

        private ConnectionThread(Socket socket) {
            setName("SimpleServer ConnectionThread " + getId());
            mmSocket = socket;
        }

        /**
         * Runs {@link #handleConnection(Socket)}, then always closes the socket, deregisters this
         * thread and notifies observers of the disconnect.
         */
        @Override
        public void run() {
            Log.v("Server thread " + getId() + " started.");
            try {
                handleConnection(mmSocket);
            } catch (Exception e) {
                // Errors caused by shutdown() closing the socket are expected; don't log them.
                if (!mStopServer) {
                    Log.e("Server error.", e);
                }
            } finally {
                close();
                mConnectionThreads.remove(this);
                notifyOnDisconnect();
                Log.v("Server thread " + getId() + " stopped.");
            }
        }

        /** Closes this connection's socket, logging (not throwing) any I/O failure. */
        private void close() {
            if (mmSocket != null) {
                try {
                    mmSocket.close();
                } catch (IOException e) {
                    Log.e(e.getMessage(), e);
                }
            }
        }
    }

    /**
     * Returns the number of active connections to this server.
     */
    public int getNumberOfConnections() {
        return mConnectionThreads.size();
    }

    /**
     * Returns an address of an active loopback interface.
     *
     * <p>An IPv4 address is preferred. If none is found, a non-IPv4 (typically IPv6) loopback
     * address is returned. If no active loopback interface has any address, the result of
     * {@link InetAddress#getLocalHost()} is returned.
     *
     * @return the private (loopback) InetAddress
     * @throws UnknownHostException If unable to resolve localhost during fallback.
     * @throws SocketException if an IOError occurs while querying the network interfaces.
     */
    public static InetAddress getPrivateInetAddress() throws UnknownHostException, SocketException {
        return findInetAddress(true);
    }

    /**
     * Returns an address of an active non-loopback interface.
     *
     * <p>An IPv4 address is preferred. If none is found, a non-IPv4 (typically IPv6) address is
     * returned. If no active non-loopback interface has any address, the result of
     * {@link InetAddress#getLocalHost()} is returned.
     *
     * @return the public (non-loopback) InetAddress
     * @throws UnknownHostException If unable to resolve localhost during fallback.
     * @throws SocketException if an IOError occurs while querying the network interfaces.
     */
    public static InetAddress getPublicInetAddress() throws UnknownHostException, SocketException {
        return findInetAddress(false);
    }

    /**
     * Shared implementation of {@link #getPrivateInetAddress()} and
     * {@link #getPublicInetAddress()}: scans active interfaces whose loopback status equals
     * {@code wantLoopback}.
     */
    private static InetAddress findInetAddress(boolean wantLoopback)
            throws UnknownHostException, SocketException {
        InetAddress candidate = null;
        Enumeration<NetworkInterface> nets = NetworkInterface.getNetworkInterfaces();
        // Older JDKs (and Android) may return null instead of an empty enumeration.
        if (nets != null) {
            for (NetworkInterface netint : Collections.list(nets)) {
                // Skip interfaces of the wrong kind (loopback vs. not) and inactive ones.
                if (netint.isLoopback() != wantLoopback || !netint.isUp()) {
                    continue;
                }
                for (InetAddress address : Collections.list(netint.getInetAddresses())) {
                    if (address instanceof Inet4Address) {
                        if (wantLoopback) {
                            Log.d("local address " + address);
                        }
                        return address; // Prefer IPv4.
                    }
                    candidate = address; // Probably IPv6; remember it as a fallback.
                }
            }
        }
        if (candidate != null) {
            return candidate; // No suitable IPv4 address; fall back to the IPv6 one.
        }
        return InetAddress.getLocalHost(); // No matching interface at all.
    }

    /**
     * Starts the RPC server bound to the loopback address from {@link #getPrivateInetAddress()}.
     *
     * <p>If {@code port} is already in use, the server is started on an unused port instead.
     *
     * @param port the port to bind to or 0 to pick any unused port
     * @return the address and port the server is bound to, or {@code null} if it failed to start
     */
    public InetSocketAddress startLocal(int port) {
        try {
            mServer = new ServerSocket(port, BACKLOG, getPrivateInetAddress());
        } catch (BindException e) {
            Log.e("Port " + port + " already in use.");
            try {
                mServer = new ServerSocket(0, BACKLOG, getPrivateInetAddress());
            } catch (IOException e1) {
                Log.e("Failed to start server.", e1);
                return null;
            }
        } catch (Exception e) {
            Log.e("Failed to start server.", e);
            return null;
        }
        return boundAddress(start());
    }

    /**
     * Starts the RPC server for public access.
     *
     * <p>The socket is bound with a {@code null} address, i.e. the wildcard address, so it
     * currently listens on all interfaces rather than only on {@link #getPublicInetAddress()}.
     *
     * @param port the port to bind to or 0 to pick any unused port
     * @return the address and port the server is bound to, or {@code null} if it failed to start
     */
    public InetSocketAddress startPublic(int port) {
        try {
            mServer = new ServerSocket(port, BACKLOG, null);
        } catch (Exception e) {
            Log.e("Failed to start server.", e);
            return null;
        }
        return boundAddress(start());
    }

    /**
     * Starts the RPC server bound to all interfaces.
     *
     * @param port the port to bind to or 0 to pick any unused port
     * @return the address and port the server is bound to, or {@code null} if it failed to start
     */
    public InetSocketAddress startAllInterfaces(int port) {
        try {
            mServer = new ServerSocket(port, BACKLOG);
        } catch (Exception e) {
            Log.e("Failed to start server.", e);
            return null;
        }
        return boundAddress(start());
    }

    /** Builds the unresolved socket address reported by the start methods. */
    private InetSocketAddress boundAddress(int boundPort) {
        return InetSocketAddress.createUnresolved(mServer.getInetAddress().getHostAddress(),
                boundPort);
    }

    /**
     * Launches the accept loop on a background thread.
     *
     * @return the local port {@link #mServer} is bound to
     */
    private int start() {
        mServerThread = new Thread(this::acceptLoop, "SimpleServer ServerThread");
        mServerThread.start();
        Log.v("Bound to " + mServer.getInetAddress());
        return mServer.getLocalPort();
    }

    /** Accepts connections until {@link #shutdown()} is called or the server socket is closed. */
    private void acceptLoop() {
        while (!mStopServer) {
            try {
                Socket sock = mServer.accept();
                if (!mStopServer) {
                    startConnectionThread(sock);
                } else {
                    sock.close();
                }
            } catch (IOException e) {
                if (!mStopServer) {
                    Log.e("Failed to accept connection.", e);
                }
                // accept() fails immediately and forever on a closed socket; stop rather than
                // spin in a tight error loop.
                if (mServer.isClosed()) {
                    break;
                }
            }
        }
    }

    /**
     * Registers and starts a {@code ConnectionThread} for {@code sock}.
     *
     * <p>Observers are notified of the connection before the thread starts, so a very short-lived
     * connection cannot report {@code onDisconnect} before {@code onConnect}.
     */
    protected void startConnectionThread(final Socket sock) {
        ConnectionThread networkThread = new ConnectionThread(sock);
        mConnectionThreads.add(networkThread);
        notifyOnConnect();
        networkThread.start();
    }

    /** Closes the server, preventing new connections from being added. */
    public void shutdown() {
        // Stop listening on the server socket to ensure that
        // beyond this point there are no incoming requests.
        mStopServer = true;
        if (mServer != null) { // Null when no start method ever succeeded.
            try {
                mServer.close();
            } catch (IOException e) {
                Log.e("Failed to close server socket.", e);
            }
        }
        // Since the server is not running, mConnectionThreads can only shrink from this point
        // onward, so we can just close every running connection. In the worst case one of the
        // threads has already shut down. Because this is a CopyOnWriteArrayList, iterating while
        // threads remove themselves is safe.
        for (ConnectionThread connectionThread : mConnectionThreads) {
            connectionThread.close();
        }
        // Removing observers one by one while iterating the list caused
        // ConcurrentModificationException; clear it in a single operation instead.
        mObservers.clear();
    }
}
