/*
 * Copyright (C) 2022 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.rkpdapp.provisioner;

import android.content.Context;
import android.os.RemoteException;
import android.util.Log;

import com.android.rkpdapp.GeekResponse;
import com.android.rkpdapp.RkpdException;
import com.android.rkpdapp.database.InstantConverter;
import com.android.rkpdapp.database.ProvisionedKey;
import com.android.rkpdapp.database.ProvisionedKeyDao;
import com.android.rkpdapp.database.RkpKey;
import com.android.rkpdapp.interfaces.ServerInterface;
import com.android.rkpdapp.interfaces.SystemInterface;
import com.android.rkpdapp.metrics.ProvisioningAttempt;
import com.android.rkpdapp.utils.Settings;
import com.android.rkpdapp.utils.StatsProcessor;
import com.android.rkpdapp.utils.X509Utils;

import java.security.cert.X509Certificate;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import co.nstant.in.cbor.CborException;

/**
 * Provides an easy package to run the provisioning process from start to finish, interfacing
 * with the system interface and the server backend in order to provision attestation certificates
 * to the device.
 */
public class Provisioner {
    private static final String TAG = "RkpdProvisioner";
    private static final int FAILURE_MAXIMUM = 5;
    private static final Object provisionKeysLock = new Object();

    private final Context mContext;
    private final ProvisionedKeyDao mKeyDao;
    private final boolean mIsAsync;

    public Provisioner(final Context applicationContext, ProvisionedKeyDao keyDao,
            boolean isAsync) {
        mContext = applicationContext;
        mKeyDao = keyDao;
        mIsAsync = isAsync;
    }

    /**
     * Check to see if we need to perform provisioning or not for the given
     * IRemotelyProvisionedComponent.
     * @param serviceName the name of the remotely provisioned component to be provisioned
     * @return true if the remotely provisioned component requires more keys, false if the pool
     *         of available keys is healthy.
     */
    public boolean isProvisioningNeeded(ProvisioningAttempt metrics, String serviceName) {
        return calculateKeysRequired(metrics, serviceName) > 0;
    }

    /**
     * Generate, sign and store remotely provisioned keys.
     */
    public void provisionKeys(ProvisioningAttempt metrics, SystemInterface systemInterface,
            GeekResponse geekResponse) throws CborException, RkpdException, InterruptedException {
        synchronized (provisionKeysLock) {
            try {
                int keysRequired = calculateKeysRequired(metrics, systemInterface.getServiceName());
                Log.i(TAG, "Requested number of keys for provisioning: " + keysRequired);
                if (keysRequired == 0) {
                    metrics.setStatus(ProvisioningAttempt.Status.NO_PROVISIONING_NEEDED);
                    return;
                }

                List<RkpKey> keysGenerated = generateKeys(metrics, keysRequired, systemInterface);
                checkForInterrupts();
                List<byte[]> certChains = fetchCertificates(metrics, keysGenerated, systemInterface,
                        geekResponse);
                checkForInterrupts();
                List<ProvisionedKey> keys = associateCertsWithKeys(certChains, keysGenerated);

                mKeyDao.insertKeys(keys);
                Log.i(TAG, "Total provisioned keys: " + keys.size());
                metrics.setStatus(ProvisioningAttempt.Status.KEYS_SUCCESSFULLY_PROVISIONED);
            } catch (InterruptedException e) {
                metrics.setStatus(ProvisioningAttempt.Status.INTERRUPTED);
                throw e;
            } catch (RkpdException e) {
                if (Settings.getFailureCounter(mContext) > FAILURE_MAXIMUM) {
                    Log.e(TAG, "Too many failures, resetting defaults.");
                    Settings.resetDefaultConfig(mContext);
                }
                // Rethrow to provide failure signal to caller
                throw e;
            }
        }
    }

    private List<RkpKey> generateKeys(ProvisioningAttempt metrics, int numKeysRequired,
            SystemInterface systemInterface)
            throws CborException, RkpdException, InterruptedException {
        List<RkpKey> keyArray = new ArrayList<>(numKeysRequired);
        checkForInterrupts();
        for (long i = 0; i < numKeysRequired; i++) {
            keyArray.add(systemInterface.generateKey(metrics));
        }
        return keyArray;
    }

    private List<byte[]> fetchCertificates(ProvisioningAttempt metrics, List<RkpKey> keysGenerated,
            SystemInterface systemInterface, GeekResponse geekResponse)
            throws RkpdException, CborException, InterruptedException {
        int provisionedSoFar = 0;
        List<byte[]> certChains = new ArrayList<>(keysGenerated.size());
        int maxBatchSize;
        try {
            maxBatchSize = systemInterface.getBatchSize();
        } catch (RemoteException e) {
            throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
                    "Error getting batch size from the system", e);
        }
        while (provisionedSoFar != keysGenerated.size()) {
            int batchSize = Math.min(keysGenerated.size() - provisionedSoFar, maxBatchSize);
            certChains.addAll(batchProvision(metrics, systemInterface, geekResponse,
                    keysGenerated.subList(provisionedSoFar, batchSize + provisionedSoFar)));
            provisionedSoFar += batchSize;
        }
        return certChains;
    }

    private List<byte[]> batchProvision(ProvisioningAttempt metrics,
            SystemInterface systemInterface,
            GeekResponse response, List<RkpKey> keysGenerated)
            throws RkpdException, CborException, InterruptedException {
        int batch_size = keysGenerated.size();
        if (batch_size < 1) {
            throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
                    "Request at least 1 key to be signed. Num requested: " + batch_size);
        }
        byte[] certRequest = systemInterface.generateCsr(metrics, response, keysGenerated);
        if (certRequest == null) {
            throw new RkpdException(RkpdException.ErrorCode.INTERNAL_ERROR,
                    "Failed to serialize payload");
        }
        return new ServerInterface(mContext, mIsAsync).requestSignedCertificates(certRequest,
                response.getChallenge(), metrics);
    }

    private List<ProvisionedKey> associateCertsWithKeys(List<byte[]> certChains,
            List<RkpKey> keysGenerated) throws RkpdException {
        List<ProvisionedKey> provisionedKeys = new ArrayList<>();
        for (byte[] chain : certChains) {
            X509Certificate[] certChain = X509Utils.formatX509Certs(chain);
            X509Certificate leafCertificate = certChain[0];
            long expirationDate = X509Utils.getExpirationTimeForCertificateChain(certChain)
                    .toInstant().toEpochMilli();
            byte[] rawPublicKey = X509Utils.getAndFormatRawPublicKey(leafCertificate);
            if (rawPublicKey == null) {
                Log.e(TAG, "Skipping malformed public key.");
                continue;
            }
            for (RkpKey key : keysGenerated) {
                if (Arrays.equals(key.getPublicKey(), rawPublicKey)) {
                    provisionedKeys.add(key.generateProvisionedKey(chain,
                            InstantConverter.fromTimestamp(expirationDate)));
                    keysGenerated.remove(key);
                    break;
                }
            }
        }
        return provisionedKeys;
    }

    /**
     * Calculate the number of keys to be provisioned.
     */
    private int calculateKeysRequired(ProvisioningAttempt metrics, String serviceName) {
        int numExtraAttestationKeys = Settings.getExtraSignedKeysAvailable(mContext);
        Instant expirationTime = Settings.getExpirationTime(mContext);
        StatsProcessor.PoolStats poolStats = StatsProcessor.processPool(mKeyDao, serviceName,
                numExtraAttestationKeys, expirationTime);
        metrics.setIsKeyPoolEmpty(poolStats.keysUnassigned == 0);
        return poolStats.keysToGenerate;
    }

    private void checkForInterrupts() throws InterruptedException {
        if (Thread.interrupted()) {
            throw new InterruptedException();
        }
    }

    /**
     * Clears bad attestation keys on the basis of information provided in the FetchGeek response.
     */
    public void clearBadAttestationKeys(GeekResponse resp) {
        if (resp.lastBadCertTimeStart == null || resp.lastBadCertTimeEnd == null) {
            // if there is no time sent, no need to do anything.
            return;
        }
        if (resp.lastBadCertTimeStart.equals(Settings.getLastBadCertTimeStart(mContext))
                && resp.lastBadCertTimeEnd.equals(Settings.getLastBadCertTimeEnd(mContext))) {
            // if the time is same as already stored version, no need to do anything.
            return;
        }
        // clear the attestation keys on the basis of time.
        checkAndDeleteBadKeys(resp.lastBadCertTimeStart, resp.lastBadCertTimeEnd);

        // store the time.
        Settings.setLastBadCertTimeRange(mContext, resp.lastBadCertTimeStart,
                resp.lastBadCertTimeEnd);
    }

    private void checkAndDeleteBadKeys(Instant startTime, Instant endTime) {
        try {
            List<ProvisionedKey> allKeys = mKeyDao.getAllKeys();
            for (int i = 0; i < allKeys.size(); i++) {
                ProvisionedKey key = allKeys.get(i);
                X509Certificate[] certChain = X509Utils.formatX509Certs(key.certificateChain);
                X509Certificate leafCertificate = certChain[0];
                Instant creationTime = leafCertificate.getNotBefore().toInstant()
                        .truncatedTo(ChronoUnit.MILLIS);

                if (!creationTime.isBefore(startTime) && !creationTime.isAfter(endTime)) {
                    mKeyDao.deleteKey(key.keyBlob);
                }
            }
        } catch (RkpdException ex) {
            Log.e(TAG, "Could not convert certificate chain to X509 certificates.", ex);
        }
    }
}
