/*
 * Copyright (C) 2024 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package android.adservices.ondevicepersonalization;

import android.adservices.ondevicepersonalization.aidl.IDataAccessService;
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelService;
import android.adservices.ondevicepersonalization.aidl.IIsolatedModelServiceCallback;
import android.annotation.CallbackExecutor;
import android.annotation.FlaggedApi;
import android.annotation.NonNull;
import android.annotation.WorkerThread;
import android.os.Bundle;
import android.os.OutcomeReceiver;
import android.os.RemoteException;

import com.android.adservices.ondevicepersonalization.flags.Flags;
import com.android.ondevicepersonalization.internal.util.LoggerFactory;

import java.util.Objects;
import java.util.concurrent.Executor;

/**
 * Handles model inference. Only TFLite model inference is supported at present. See {@link
 * IsolatedService#getModelManager}.
 */
@FlaggedApi(Flags.FLAG_ON_DEVICE_PERSONALIZATION_APIS_ENABLED)
public class ModelManager {
    private static final LoggerFactory.Logger sLogger = LoggerFactory.getLogger();
    private static final String TAG = ModelManager.class.getSimpleName();

    /** Binder handed to the model service with each request; also used to report call stats. */
    @NonNull private final IDataAccessService mDataService;

    /** Remote service that actually executes the inference. */
    @NonNull private final IIsolatedModelService mModelService;

    /** @hide */
    public ModelManager(
            @NonNull IDataAccessService dataService, @NonNull IIsolatedModelService modelService) {
        mDataService = dataService;
        mModelService = modelService;
    }

    /**
     * Run a single model inference. Only supports TFLite model inference now.
     *
     * <p>The outcome is always reported through {@code receiver} on {@code executor}, including
     * the case where the request cannot be sent to the model service. Exactly one of {@link
     * OutcomeReceiver#onResult} or {@link OutcomeReceiver#onError} is invoked per call.
     *
     * @param input contains all the information needed for a run of model inference.
     * @param executor the {@link Executor} on which to invoke the callback.
     * @param receiver this returns a {@link InferenceOutput} which contains model inference result
     *     or {@link Exception} on failure.
     * @throws NullPointerException if {@code input}, {@code executor} or {@code receiver} is
     *     {@code null}.
     */
    @WorkerThread
    public void run(
            @NonNull InferenceInput input,
            @NonNull @CallbackExecutor Executor executor,
            @NonNull OutcomeReceiver<InferenceOutput, Exception> receiver) {
        final long startTimeMillis = System.currentTimeMillis();
        // Fail fast on the caller's thread instead of surfacing an NPE later on a binder or
        // executor thread where the caller cannot observe it.
        Objects.requireNonNull(input);
        Objects.requireNonNull(executor);
        Objects.requireNonNull(receiver);

        Bundle bundle = new Bundle();
        bundle.putBinder(Constants.EXTRA_DATA_ACCESS_SERVICE_BINDER, mDataService.asBinder());
        bundle.putParcelable(Constants.EXTRA_INFERENCE_INPUT, new InferenceInputParcel(input));
        try {
            mModelService.runInference(
                    bundle,
                    new IIsolatedModelServiceCallback.Stub() {
                        @Override
                        public void onSuccess(Bundle result) {
                            executor.execute(
                                    () -> deliverResult(result, receiver, startTimeMillis));
                        }

                        @Override
                        public void onError(int errorCode) {
                            executor.execute(
                                    () ->
                                            deliverError(
                                                    new IllegalStateException(
                                                            "Error: " + errorCode),
                                                    receiver,
                                                    startTimeMillis));
                        }
                    });
        } catch (RemoteException e) {
            // The request never reached the model service. Report the failure on the supplied
            // executor (as documented for the callback) and record it in the call stats.
            executor.execute(
                    () -> deliverError(new IllegalStateException(e), receiver, startTimeMillis));
        }
    }

    /**
     * Unpacks the service response and hands it to {@code receiver}; a malformed or missing
     * response is reported through {@link OutcomeReceiver#onError} instead.
     */
    private void deliverResult(
            Bundle result,
            OutcomeReceiver<InferenceOutput, Exception> receiver,
            long startTimeMillis) {
        final InferenceOutput output;
        try {
            InferenceOutputParcel outputParcel =
                    Objects.requireNonNull(
                            result.getParcelable(
                                    Constants.EXTRA_RESULT, InferenceOutputParcel.class));
            output = new InferenceOutput(outputParcel.getData());
        } catch (Exception e) {
            deliverError(e, receiver, startTimeMillis);
            return;
        }
        // Delivery is kept outside the parsing try-block: an exception thrown by the client's
        // onResult must not trigger a second callback (onError) on the same receiver.
        final long endTimeMillis = System.currentTimeMillis();
        try {
            receiver.onResult(output);
        } finally {
            logApiCallStats(
                    Constants.API_NAME_MODEL_MANAGER_RUN,
                    endTimeMillis - startTimeMillis,
                    Constants.STATUS_SUCCESS);
        }
    }

    /** Reports {@code error} to {@code receiver} and logs the call as an internal error. */
    private void deliverError(
            Exception error,
            OutcomeReceiver<InferenceOutput, Exception> receiver,
            long startTimeMillis) {
        final long endTimeMillis = System.currentTimeMillis();
        try {
            receiver.onError(error);
        } finally {
            // Logged in finally so a throwing receiver cannot suppress the stats record.
            logApiCallStats(
                    Constants.API_NAME_MODEL_MANAGER_RUN,
                    endTimeMillis - startTimeMillis,
                    Constants.STATUS_INTERNAL_ERROR);
        }
    }

    /**
     * Best-effort reporting of API call stats through the data access service; any failure is
     * logged at debug level and otherwise ignored so it never affects the inference outcome.
     */
    private void logApiCallStats(int apiName, long duration, int responseCode) {
        try {
            mDataService.logApiCallStats(apiName, duration, responseCode);
        } catch (Exception e) {
            sLogger.d(e, TAG + ": failed to log metrics");
        }
    }
}
