/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.rkpdapp.unittest;

import static com.google.common.truth.Truth.assertThat;
import static com.google.common.truth.Truth.assertWithMessage;

import android.content.Context;
import android.content.pm.ApplicationInfo;
import android.content.pm.PackageManager;
import android.util.Base64;

import androidx.test.core.app.ApplicationProvider;

import com.android.rkpdapp.GeekResponse;
import com.android.rkpdapp.RkpdException;
import com.android.rkpdapp.interfaces.ServerInterface;
import com.android.rkpdapp.metrics.ProvisioningAttempt;
import com.android.rkpdapp.testutil.FakeRkpServer;
import com.android.rkpdapp.utils.CborUtils;
import com.android.rkpdapp.utils.Settings;

import org.junit.After;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.mockito.Mockito;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;

/**
 * Unit tests for {@link ServerInterface}.
 *
 * <p>Covers fetching the GEEK and requesting signed certificates against a local
 * {@link FakeRkpServer}, data-budget and connectivity gating, parsing of server error bodies via
 * {@link ServerInterface#readErrorFromConnection}, connect-timeout selection by region, and the
 * network-consent check.
 */
public class ServerInterfaceTest {
    private static final Duration TIME_TO_REFRESH_HOURS = Duration.ofHours(2);
    private static Context sContext;
    private ServerInterface mServerInterface;

    /** Wraps the instrumentation context in a Mockito spy so connectivity can be stubbed. */
    @BeforeClass
    public static void init() {
        sContext = Mockito.spy(ApplicationProvider.getApplicationContext());
    }

    /** Starts each test from cleared preferences and a connected network. */
    @Before
    public void setUp() {
        Settings.clearPreferences(sContext);
        mServerInterface = new ServerInterface(sContext, false);
        Utils.mockConnectivityState(sContext, Utils.ConnectivityState.CONNECTED);
    }

    /** Clears preferences and drops any stubbing applied to the spied context. */
    @After
    public void tearDown() {
        Settings.clearPreferences(sContext);
        Mockito.reset(sContext);
    }

    /** A server that only returns INTERNAL_ERROR must surface HTTP_SERVER_ERROR from fetchGeek. */
    @Test
    public void testRetryOnServerFailure() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(FakeRkpServer.Response.INTERNAL_ERROR,
                FakeRkpServer.Response.INTERNAL_ERROR)) {
            Settings.setDeviceConfig(sContext, 1 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            // NOTE(review): presumably caps total retry time so the test finishes quickly; confirm
            // the unit against Settings.setMaxRequestTime.
            Settings.setMaxRequestTime(sContext, 100);
            GeekResponse ignored = mServerInterface.fetchGeek(
                    ProvisioningAttempt.createScheduledAttemptMetrics(sContext));
            assertWithMessage("Expected RkpdException.").fail();
        } catch (RkpdException e) {
            assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.HTTP_SERVER_ERROR);
            assertThat(e).hasMessageThat().contains("HTTP error status encountered");
        }
    }

    /** An RKP-disabled fetchEek response yields zero extra keys but still carries a challenge. */
    @Test
    public void testFetchGeekRkpDisabled() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_RKP_DISABLED,
                FakeRkpServer.Response.INTERNAL_ERROR)) {
            Settings.setDeviceConfig(sContext, 1 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            GeekResponse response = mServerInterface.fetchGeek(
                    ProvisioningAttempt.createScheduledAttemptMetrics(sContext));

            assertThat(response.numExtraAttestationKeys).isEqualTo(0);
            assertThat(response.getChallenge()).isNotNull();
            // NOTE(review): curve id 2 is presumably CborUtils.EC_CURVE_25519 -- TODO confirm.
            assertThat(response.getGeekChain(2)).isNotNull();
        }
    }

    /** A successful fetchEek response is parsed into the expected challenge and GEEK chains. */
    @Test
    public void testFetchGeekRkpEnabled() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 1 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            GeekResponse response = mServerInterface.fetchGeek(
                    ProvisioningAttempt.createScheduledAttemptMetrics(sContext));

            // The expected values below must match FakeRkpServer's FETCH_EEK_OK payload.
            assertThat(response.numExtraAttestationKeys).isEqualTo(20);
            assertThat(response.getChallenge()).isNotNull();
            byte[] challenge = Base64.decode("AAABgEg1zGsBILStY/1VNI7st0AG9x2S/tba+H4=",
                    Base64.DEFAULT);
            assertThat(response.getChallenge()).isEqualTo(challenge);
            byte[] ed25519GeekChain = Base64.decode(
                    "g4RDoQEnoFgqpAEBAycgBiFYIJm57t1e5FL2hcZMYtw+YatXS"
                            + "H11NymtdoAy0rPLY1jZWEAeIghLpLekyNdOAw7+uK8UTKc7b6XN3Np5xitk"
                            + "/pk5r3bngPpmAIUNB5gqrJFcpyUUSQY0dcqKJ3rZ41pJ6wIDhEOhASegWCqk"
                            + "AQEDJyAGIVgg6i+FDp5qDFz3vdn6KDK/2lXpIKJRA8kDkxjOoBUp7NFYQIJr"
                            + "x12mNle3x3ESrRzCarMsIyrdFDDLghS2icXTHjG7uFAhSklNupEMbzNNg7xY"
                            + "Ky6E28VZD5hh4sHqifLQrgSEQ6EBJ6BYTqUBAQJYIG+S0QRtcdinjojY0VaB"
                            + "X5bReIPmMBuH7b8g0Uo7/mouAzgYIAQhWCC2XRxLmoM6nbUVWTehJvsP3+ec"
                            + "rAHVpOzIOikAiFglOVhAgLKf0DKenUr+sCXywtIiaEbGILCq6BasZKFFg5vM"
                            + "SVQlf6sWBVPwvTWT88a7WU5e+d4hBxSjtqSji4+Clpa6Aw==",
                    Base64.DEFAULT);
            byte[] p256GeekChain = Base64.decode(
                    "g4RDoQEmoFhNpQECAyYgASFYIPcUituX9MxT79JkEcTjdR9mH6Rx"
                            + "DGzP+glGgHSHVPKtIlggXn9b9uzk9hnM/xM3/Q+hyJPbGAZ2xF3m12p3hsMtr49YQC"
                            + "+XjkL7vgctlUeFR5NAsB/Um0ekxESp8qEHhxDHn8sR9L+f6Dvg5zRMFfx7w34zBfTR"
                            + "NDztAgRgehXgedOK/ySEQ6EBJqBYTaUBAgMmIAEhWCBRgKzPj5aM7A9Q4akbt5CGNI"
                            + "vjw6xlAk209jEOCEYyOSJYIFTrlJ3+trTkczolTi8fnZ29+mbBEYvploxD5DD22nar"
                            + "WECYOPs0OmXbc5ixJ6IVdPK+BueNIk7d8L/CAXTEtylrJBy12NJm+kTv9TAsBHTt6M"
                            + "Zg2s6fVlcndCHT3pOP47jNhEOhASagWHGmAQICWCCDn/j9EBwSn5JBx1uN5E70GROa"
                            + "xxttpw6V8mRTXacdwQM4GCABIVggFqRSEmOzhlZQ2N/yoKh9vNlup2hg6oxc8ZPllx"
                            + "kNrN4iWCCJvsxsP16wOTSvl7o40RYdocwdZNOMSE74coEbOz4x7lhA+trPLaulMAxz"
                            + "xeWrSZJZYET6xPIz5QSybBlk6RzjZDs0hgBlLfXdr6oBya+DyU74WpToZZNR4xgeOY"
                            + "CnaUszzQ==",
                    Base64.DEFAULT);
            assertThat(response.getGeekChain(CborUtils.EC_CURVE_25519)).isEqualTo(ed25519GeekChain);
            assertThat(response.getGeekChain(CborUtils.EC_CURVE_P256)).isEqualTo(p256GeekChain);
        }
    }

    /** fetchGeekAndUpdate persists the server-provided key count and expiry into Settings. */
    @Test
    public void testFetchKeyAndUpdate() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            mServerInterface.fetchGeekAndUpdate(
                    ProvisioningAttempt.createScheduledAttemptMetrics(sContext));

            // 20 keys / 72h presumably come from FakeRkpServer's FETCH_EEK_OK payload, overriding
            // the locally configured 2 keys / 2h.
            assertThat(Settings.getExtraSignedKeysAvailable(sContext)).isEqualTo(20);
            assertThat(Settings.getExpiringBy(sContext)).isEqualTo(Duration.ofHours(72));
        }
    }

    /** A "device unregistered" signing response maps to DEVICE_NOT_REGISTERED. */
    @Test
    public void testRequestSignedCertUnregistered() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_DEVICE_UNREGISTERED)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);
            mServerInterface.requestSignedCertificates(new byte[0], new byte[0], metrics);
            assertWithMessage("Should fail due to unregistered device.").fail();
        } catch (RkpdException e) {
            assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.DEVICE_NOT_REGISTERED);
        }
    }

    /** An "unauthorized" signing response maps to HTTP_CLIENT_ERROR. */
    @Test
    public void testRequestSignedCertClientError() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_USER_UNAUTHORIZED)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            Settings.setMaxRequestTime(sContext, 100);
            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);
            mServerInterface.requestSignedCertificates(new byte[0], new byte[0], metrics);
            assertWithMessage("Should fail due to client error.").fail();
        } catch (RkpdException e) {
            assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.HTTP_CLIENT_ERROR);
        }
    }

    /** A successful HTTP status with an unparseable CBOR body maps to INTERNAL_ERROR. */
    @Test
    public void testRequestSignedCertCborError() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_INVALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);
            mServerInterface.requestSignedCertificates(new byte[0], new byte[0], metrics);
            assertWithMessage("Should fail due to invalid cbor.").fail();
        } catch (RkpdException e) {
            assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.INTERNAL_ERROR);
            assertThat(e).hasMessageThat().isEqualTo("Response failed to parse.");
        }
    }

    /** A valid signing response parses to a non-null (here empty) list of cert chains. */
    @Test
    public void testRequestSignedCertValid() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());
            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);
            List<byte[]> certChains = mServerInterface.requestSignedCertificates(new byte[0],
                    new byte[0], metrics);
            // Check for null first so a null result produces a clear, dedicated failure.
            assertThat(certChains).isNotNull();
            assertThat(certChains).isEmpty();
        }
    }

    /** Once the error data budget is exhausted, fetchGeek must refuse to touch the network. */
    @Test
    public void testDataBudgetEmptyFetchGeekNetworkConnected() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());

            // Check the data budget in order to initialize a rolling window.
            assertThat(Settings.hasErrDataBudget(sContext, null /* curTime */)).isTrue();
            Settings.consumeErrDataBudget(sContext, Settings.FAILURE_DATA_USAGE_MAX);
            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);

            mServerInterface.fetchGeek(metrics);
            assertWithMessage("Network transaction should not have proceeded.").fail();
        } catch (RkpdException e) {
            assertThat(e).hasMessageThat().contains("Out of data budget due to repeated errors");
            assertThat(e.getErrorCode()).isEqualTo(
                    RkpdException.ErrorCode.NETWORK_COMMUNICATION_ERROR);
        }
    }

    /** With connectivity reported as disconnected, fetchGeek fails with NO_NETWORK_CONNECTIVITY. */
    @Test
    public void testNetworkDisconnected() throws Exception {
        try (FakeRkpServer server = new FakeRkpServer(
                FakeRkpServer.Response.FETCH_EEK_OK,
                FakeRkpServer.Response.SIGN_CERTS_OK_VALID_CBOR)) {
            Settings.setDeviceConfig(sContext, 2 /* extraKeys */,
                    TIME_TO_REFRESH_HOURS /* expiringBy */, server.getUrl());

            ProvisioningAttempt metrics = ProvisioningAttempt.createScheduledAttemptMetrics(
                    sContext);

            // We are okay in mocking connectivity failure since network check is the first thing
            // to happen.
            Utils.mockConnectivityState(sContext, Utils.ConnectivityState.DISCONNECTED);
            mServerInterface.fetchGeek(metrics);
            assertWithMessage("Network transaction should not have proceeded.").fail();
        } catch (RkpdException e) {
            assertThat(e).hasMessageThat().contains("No network detected");
            assertThat(e.getErrorCode()).isEqualTo(RkpdException.ErrorCode.NO_NETWORK_CONNECTIVITY);
        }
    }

    /** A content type that is neither text nor JSON is reported rather than read. */
    @Test
    public void testReadErrorInvalidContentType() {
        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("application/NOPE");
        assertThat(ServerInterface.readErrorFromConnection(connection))
                .isEqualTo("Unexpected content type from the server: application/NOPE");
    }

    /** When the input stream throws and there is no error stream, a placeholder is returned. */
    @Test
    public void testReadTextErrorFromErrorStreamNoErrorData() throws Exception {
        final String expectedError = "No error data returned by server.";

        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("text");
        Mockito.when(connection.getInputStream()).thenThrow(new IOException());
        Mockito.when(connection.getErrorStream()).thenReturn(null);

        assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(expectedError);
    }

    /** When the input stream throws, the error text is read from the error stream instead. */
    @Test
    public void testReadTextErrorFromErrorStream() throws Exception {
        final String error = "Explanation for error goes here.";

        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("text");
        Mockito.when(connection.getInputStream()).thenThrow(new IOException());
        Mockito.when(connection.getErrorStream())
                .thenReturn(new ByteArrayInputStream(error.getBytes(StandardCharsets.UTF_8)));

        assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(error);
    }

    /** Every text/* content-type variant (with or without charset) is read as text. */
    @Test
    public void testReadTextError() throws IOException {
        final String error = "This is an error.  Oh No.";
        final String[] textContentTypes = new String[]{
                "text",
                "text/ANYTHING",
                "text/what-is-this; charset=unknown",
                "text/lowercase; charset=utf-8",
                "text/uppercase; charset=UTF-8",
                "text/yolo; charset=ASCII"
        };

        for (String contentType : textContentTypes) {
            HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
            Mockito.when(connection.getContentType()).thenReturn(contentType);
            Mockito.when(connection.getInputStream())
                    .thenReturn(new ByteArrayInputStream(error.getBytes(StandardCharsets.UTF_8)));

            // The value under test goes in that(); the constant is the expectation, so that
            // Truth's "expected"/"but was" labels in a failure message are correct.
            assertWithMessage("Failed on content type '" + contentType + "'")
                    .that(ServerInterface.readErrorFromConnection(connection))
                    .isEqualTo(error);
        }
    }

    /** application/json bodies are returned verbatim (no JSON parsing is required). */
    @Test
    public void testReadJsonError() throws IOException {
        final String error = "Not really JSON.";

        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("application/json");
        Mockito.when(connection.getInputStream())
                .thenReturn(new ByteArrayInputStream(error.getBytes(StandardCharsets.UTF_8)));

        assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(error);
    }

    /** An IOException while reading the body is reported as an error string, not thrown. */
    @Test
    public void testReadErrorStreamThrowsException() throws IOException {
        InputStream stream = Mockito.mock(InputStream.class);
        // NOTE(review): only read(byte[]) is stubbed; this assumes readErrorFromConnection reads
        // via that overload -- confirm if the implementation changes.
        Mockito.when(stream.read(Mockito.any())).thenThrow(new IOException());

        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("text");
        Mockito.when(connection.getInputStream()).thenReturn(stream);

        final String error = ServerInterface.readErrorFromConnection(connection);
        assertWithMessage("Error string: '" + error + "'")
                .that(error).startsWith("Error reading error string from server: ");
    }

    /** An empty body yields the "no error data" placeholder. */
    @Test
    public void testReadErrorEmptyStream() throws IOException {
        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("text");
        Mockito.when(connection.getInputStream())
                .thenReturn(new ByteArrayInputStream(new byte[0]));

        assertThat(ServerInterface.readErrorFromConnection(connection))
                .isEqualTo("No error data returned by server.");
    }

    /** Oversized bodies are truncated to their first 1024 characters. */
    @Test
    public void testReadErrorStreamTooLarge() throws IOException {
        final StringBuilder sb = new StringBuilder();
        for (int i = 0; i < 2048; ++i) {
            sb.append(i % 100);
        }
        final String bigString = sb.toString();

        HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
        Mockito.when(connection.getContentType()).thenReturn("text");
        Mockito.when(connection.getInputStream())
                .thenReturn(new ByteArrayInputStream(bigString.getBytes(StandardCharsets.UTF_8)));

        // NOTE(review): 1024 presumably mirrors the read cap inside readErrorFromConnection;
        // keep the two in sync.
        sb.setLength(1024);
        assertThat(ServerInterface.readErrorFromConnection(connection)).isEqualTo(sb.toString());
    }

    /** Any region list containing "cn" selects the restricted connect timeout. */
    @Test
    public void testServerConnectionTimeout() {
        ServerInterface serverInterface = Mockito.spy(mServerInterface);
        Mockito.when(serverInterface.getRegionalProperty()).thenReturn("cn");
        assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
                ServerInterface.SYNC_CONNECT_TIMEOUT_RETRICTED_MS);

        Mockito.when(serverInterface.getRegionalProperty()).thenReturn("cn,us");
        assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
                ServerInterface.SYNC_CONNECT_TIMEOUT_RETRICTED_MS);

        Mockito.when(serverInterface.getRegionalProperty()).thenReturn(null);
        assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
                ServerInterface.SYNC_CONNECT_TIMEOUT_OPEN_MS);

        Mockito.when(serverInterface.getRegionalProperty()).thenReturn("");
        assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
                ServerInterface.SYNC_CONNECT_TIMEOUT_OPEN_MS);

        Mockito.when(serverInterface.getRegionalProperty()).thenReturn("us");
        assertThat(serverInterface.getConnectTimeoutMs()).isEqualTo(
                ServerInterface.SYNC_CONNECT_TIMEOUT_OPEN_MS);
    }

    /**
     * assumeNetworkConsent follows the "cn.google.services" feature. Without the feature, consent
     * is always assumed. With it, consent depends on the looked-up application being enabled, and
     * a lookup failure means no consent.
     */
    @Test
    public void testConnectionConsent() throws Exception {
        String cnGmsFeature = "cn.google.services";
        PackageManager mockedPackageManager = Mockito.mock(PackageManager.class);
        Context mockedContext = Mockito.mock(Context.class);
        ApplicationInfo fakeApplicationInfo = new ApplicationInfo();

        Mockito.when(mockedContext.getPackageManager()).thenReturn(mockedPackageManager);
        Mockito.when(mockedPackageManager.hasSystemFeature(cnGmsFeature)).thenReturn(true);
        Mockito.when(mockedPackageManager.getApplicationInfo(Mockito.any(), Mockito.eq(0)))
                .thenReturn(fakeApplicationInfo);

        fakeApplicationInfo.enabled = false;
        assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isFalse();

        fakeApplicationInfo.enabled = true;
        assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isTrue();

        Mockito.when(mockedPackageManager.getApplicationInfo(Mockito.any(), Mockito.eq(0)))
                .thenThrow(new PackageManager.NameNotFoundException());
        assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isFalse();

        Mockito.when(mockedPackageManager.hasSystemFeature(cnGmsFeature)).thenReturn(false);
        assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isTrue();

        fakeApplicationInfo.enabled = false;
        assertThat(ServerInterface.assumeNetworkConsent(mockedContext)).isTrue();
    }
}
