/*
 * Copyright 2015, Google Inc. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are
 * met:
 *
 *    * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 *    * Redistributions in binary form must reproduce the above
 * copyright notice, this list of conditions and the following disclaimer
 * in the documentation and/or other materials provided with the
 * distribution.
 *
 *    * Neither the name of Google Inc. nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
 * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
 * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package com.google.auth.oauth2;

import com.google.api.client.util.Clock;
import com.google.auth.Credentials;
import com.google.auth.RequestMetadataCallback;
import com.google.auth.http.AuthHttpConstants;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.MoreObjects;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.util.concurrent.AbstractFuture;
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListenableFutureTask;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.errorprone.annotations.CanIgnoreReturnValue;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.net.URI;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.ServiceLoader;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;

/**
 * Base type for Credentials using OAuth2.
 *
 * <p>Instances cache an {@link AccessToken} together with the request metadata derived from it.
 * Depending on the token's remaining lifetime relative to the configured refresh and expiration
 * margins, the cached token is FRESH, STALE or EXPIRED. Stale tokens keep being served while a
 * refresh runs; expired tokens make callers wait for the refresh. At most one refresh task is in
 * flight at a time.
 */
public class OAuth2Credentials extends Credentials {

  private static final long serialVersionUID = 4556936364828217687L;
  // A token is treated as expired this long before its expiration time.
  static final Duration DEFAULT_EXPIRATION_MARGIN = Duration.ofMinutes(3);
  // A token is treated as stale (and a refresh is started) this long before its expiration time.
  static final Duration DEFAULT_REFRESH_MARGIN = Duration.ofMinutes(3).plusSeconds(45);
  private static final ImmutableMap<String, List<String>> EMPTY_EXTRA_HEADERS = ImmutableMap.of();

  @VisibleForTesting private final Duration expirationMargin;
  @VisibleForTesting private final Duration refreshMargin;

  // byte[] is serializable, so the lock variable can be final
  @VisibleForTesting final Object lock = new byte[0];
  // Current token snapshot. Read without holding the lock on the fast path, hence volatile.
  private volatile OAuthValue value = null;
  // Single-flight slot for the in-flight refresh, guarded by lock.
  @VisibleForTesting transient RefreshTask refreshTask;

  // Change listeners are not serialized
  private transient List<CredentialsChangedListener> changeListeners;
  // Until we expose this to the users it can remain transient and non-serializable
  @VisibleForTesting transient Clock clock = Clock.SYSTEM;

  /**
   * Returns the credentials instance from the given access token.
   *
   * @param accessToken the access token
   * @return the credentials instance
   */
  public static OAuth2Credentials create(AccessToken accessToken) {
    return OAuth2Credentials.newBuilder().setAccessToken(accessToken).build();
  }

  /** Default constructor. */
  protected OAuth2Credentials() {
    this(null);
  }

  /**
   * Constructor with explicit access token.
   *
   * @param accessToken initial or temporary access token
   */
  protected OAuth2Credentials(AccessToken accessToken) {
    this(accessToken, DEFAULT_REFRESH_MARGIN, DEFAULT_EXPIRATION_MARGIN);
  }

  /**
   * Constructor with explicit access token and refresh/expiration margins.
   *
   * @param accessToken initial or temporary access token; may be null
   * @param refreshMargin how long before the token's expiration time it is considered stale and a
   *     refresh is started; must be non-null and non-negative
   * @param expirationMargin how long before the token's expiration time it is considered expired
   *     and callers must wait for a refresh; must be non-null and non-negative
   * @throws NullPointerException if either margin is null
   * @throws IllegalArgumentException if either margin is negative
   */
  protected OAuth2Credentials(
      AccessToken accessToken, Duration refreshMargin, Duration expirationMargin) {
    if (accessToken != null) {
      this.value = OAuthValue.create(accessToken, EMPTY_EXTRA_HEADERS);
    }

    this.refreshMargin = Preconditions.checkNotNull(refreshMargin, "refreshMargin");
    Preconditions.checkArgument(!refreshMargin.isNegative(), "refreshMargin can't be negative");
    this.expirationMargin = Preconditions.checkNotNull(expirationMargin, "expirationMargin");
    Preconditions.checkArgument(
        !expirationMargin.isNegative(), "expirationMargin can't be negative");
  }

  @Override
  public String getAuthenticationType() {
    return "OAuth2";
  }

  @Override
  public boolean hasRequestMetadata() {
    return true;
  }

  @Override
  public boolean hasRequestMetadataOnly() {
    return true;
  }

  /**
   * Returns the cached access token.
   *
   * <p>If not set, you should call {@link #refresh()} to fetch and cache an access token.
   *
   * @return The cached access token, or null if no token is cached.
   */
  public final AccessToken getAccessToken() {
    OAuthValue localState = value;
    if (localState != null) {
      return localState.temporaryAccess;
    }
    return null;
  }

  /** Returns the credentials' refresh margin. */
  @VisibleForTesting
  Duration getRefreshMargin() {
    return this.refreshMargin;
  }

  /** Returns the credentials' expiration margin. */
  @VisibleForTesting
  Duration getExpirationMargin() {
    return this.expirationMargin;
  }

  /**
   * Asynchronously provides the request metadata. A refresh, if one is needed, is run on {@code
   * executor}, and the result or failure is delivered to {@code callback}.
   */
  @Override
  public void getRequestMetadata(
      final URI uri, Executor executor, final RequestMetadataCallback callback) {

    Futures.addCallback(
        asyncFetch(executor),
        new FutureCallbackToMetadataCallbackAdapter(callback),
        MoreExecutors.directExecutor());
  }

  /**
   * Provide the request metadata by ensuring there is a current access token and providing it as an
   * authorization bearer token.
   */
  @Override
  public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException {
    return unwrapDirectFuture(asyncFetch(MoreExecutors.directExecutor())).requestMetadata;
  }

  /**
   * Request a new token regardless of the current token state. If the current token is not expired,
   * it will still be returned during the refresh.
   */
  @Override
  public void refresh() throws IOException {
    AsyncRefreshResult refreshResult = getOrCreateRefreshTask();
    refreshResult.executeIfNew(MoreExecutors.directExecutor());
    unwrapDirectFuture(refreshResult.task);
  }

  /**
   * Refresh these credentials only if they have expired or are expiring imminently.
   *
   * @throws IOException during token refresh.
   */
  public void refreshIfExpired() throws IOException {
    // asyncFetch will ensure that the token is refreshed
    unwrapDirectFuture(asyncFetch(MoreExecutors.directExecutor()));
  }

  /**
   * Attempts to get a fresh token.
   *
   * <p>If a fresh token is already available, it will be immediately returned. Otherwise a refresh
   * will be scheduled using the passed in executor. While a token is being refreshed, a stale (but
   * not yet expired) value will be returned; if the cached value is expired, the returned future
   * completes when the refresh does.
   */
  private ListenableFuture<OAuthValue> asyncFetch(Executor executor) {
    AsyncRefreshResult refreshResult = null;

    // fast and common path: skip the lock if the token is fresh
    // The inherent race condition here is a non-issue: even if the value gets replaced after the
    // state check, the new token will still be fresh.
    if (getState() == CacheState.FRESH) {
      return Futures.immediateFuture(value);
    }

    // Schedule a refresh as necessary
    synchronized (lock) {
      if (getState() != CacheState.FRESH) {
        refreshResult = getOrCreateRefreshTask();
      }
    }
    // Execute the refresh if necessary. This should be done outside of the lock to avoid blocking
    // metadata requests during a stale refresh.
    if (refreshResult != null) {
      refreshResult.executeIfNew(executor);
    }

    synchronized (lock) {
      // Immediately resolve the token if it's not expired, or wait for the refresh task to
      // complete
      if (getState() != CacheState.EXPIRED) {
        return Futures.immediateFuture(value);
      } else if (refreshResult != null) {
        return refreshResult.task;
      } else {
        // Should never happen
        return Futures.immediateFailedFuture(
            new IllegalStateException("Credentials expired, but there is no task to refresh"));
      }
    }
  }

  /**
   * Atomically creates a single flight refresh token task.
   *
   * <p>Only a single refresh task can be scheduled at a time. If there is an existing task, it will
   * be returned for subsequent invocations. However if a new task is created, it is the
   * responsibility of the caller to execute it. The task will clear the single flight slot upon
   * completion.
   */
  private AsyncRefreshResult getOrCreateRefreshTask() {
    synchronized (lock) {
      if (refreshTask != null) {
        return new AsyncRefreshResult(refreshTask, false);
      }

      final ListenableFutureTask<OAuthValue> task =
          ListenableFutureTask.create(
              new Callable<OAuthValue>() {
                @Override
                public OAuthValue call() throws Exception {
                  return OAuthValue.create(refreshAccessToken(), getAdditionalHeaders());
                }
              });

      refreshTask = new RefreshTask(task, new RefreshTaskListener(task));

      return new AsyncRefreshResult(refreshTask, true);
    }
  }

  /**
   * Async callback for committing the result from a token refresh.
   *
   * <p>The result will be stored, listeners are invoked and the single flight slot is cleared.
   */
  private void finishRefreshAsync(ListenableFuture<OAuthValue> finishedTask) {
    synchronized (lock) {
      try {
        this.value = Futures.getDone(finishedTask);
        // changeListeners is created lazily by addChangeListener and is transient, so it is null
        // until a listener is registered (and again after deserialization).
        if (changeListeners != null) {
          for (CredentialsChangedListener listener : changeListeners) {
            listener.onChanged(this);
          }
        }
      } catch (Exception e) {
        // Best effort: a failed refresh is reported to callers through the RefreshTask future, and
        // a failing listener must not prevent the single flight slot from being cleared below.
      } finally {
        if (this.refreshTask != null && this.refreshTask.getTask() == finishedTask) {
          this.refreshTask = null;
        }
      }
    }
  }

  /**
   * Unwraps the value from the future.
   *
   * <p>Under most circumstances, the underlying future will already be resolved by the
   * DirectExecutor. In those cases, the error stacktraces will be rooted in the caller's call tree.
   * However, in some cases when async and sync usage is mixed, it's possible that a blocking call
   * will await an async future. In those cases, the stacktrace will be orphaned and be rooted in a
   * thread of whatever executor the async call used. This doesn't affect correctness and is
   * extremely unlikely.
   *
   * @throws IOException if the future failed with an IOException or a checked exception, or if the
   *     wait was interrupted (the interrupt flag is restored)
   */
  private static <T> T unwrapDirectFuture(ListenableFuture<T> future) throws IOException {
    try {
      return future.get();
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw new IOException("Interrupted while asynchronously refreshing the access token", e);
    } catch (ExecutionException e) {
      Throwable cause = e.getCause();
      if (cause instanceof IOException) {
        throw (IOException) cause;
      } else if (cause instanceof RuntimeException) {
        throw (RuntimeException) cause;
      } else {
        throw new IOException("Unexpected error refreshing access token", cause);
      }
    }
  }

  /** Computes the effective credential state in relation to the current time. */
  private CacheState getState() {
    OAuthValue localValue = value;

    if (localValue == null) {
      return CacheState.EXPIRED;
    }
    Date expirationTime = localValue.temporaryAccess.getExpirationTime();

    // Tokens without an expiration time never need a refresh.
    if (expirationTime == null) {
      return CacheState.FRESH;
    }

    Duration remaining = Duration.ofMillis(expirationTime.getTime() - clock.currentTimeMillis());
    if (remaining.compareTo(expirationMargin) <= 0) {
      return CacheState.EXPIRED;
    }

    if (remaining.compareTo(refreshMargin) <= 0) {
      return CacheState.STALE;
    }

    return CacheState.FRESH;
  }

  /**
   * Method to refresh the access token according to the specific type of credentials.
   *
   * <p>Throws IllegalStateException if not overridden since direct use of OAuth2Credentials is only
   * for temporary or non-refreshing access tokens.
   *
   * @return never
   * @throws IllegalStateException always. OAuth2Credentials does not support refreshing the access
   *     token. An instance with a new access token or a derived type that supports refreshing
   *     should be used instead.
   */
  public AccessToken refreshAccessToken() throws IOException {
    throw new IllegalStateException(
        "OAuth2Credentials instance does not support refreshing the"
            + " access token. An instance with a new access token should be used, or a derived type"
            + " that supports refreshing.");
  }

  /**
   * Provide additional headers to return as request metadata.
   *
   * @return additional headers
   */
  protected Map<String, List<String>> getAdditionalHeaders() {
    return EMPTY_EXTRA_HEADERS;
  }

  /**
   * Adds a listener that is notified when the Credentials data changes.
   *
   * <p>This is called when token content changes, such as when the access token is refreshed. This
   * is typically used by code caching the access token.
   *
   * @param listener the listener to be added
   */
  public final void addChangeListener(CredentialsChangedListener listener) {
    synchronized (lock) {
      if (changeListeners == null) {
        changeListeners = new ArrayList<>();
      }
      changeListeners.add(listener);
    }
  }

  /**
   * Removes a listener that was added previously.
   *
   * @param listener The listener to be removed.
   */
  public final void removeChangeListener(CredentialsChangedListener listener) {
    synchronized (lock) {
      if (changeListeners != null) {
        changeListeners.remove(listener);
      }
    }
  }

  /**
   * Listener for changes to credentials.
   *
   * <p>This is called when token content changes, such as when the access token is refreshed. This
   * is typically used by code caching the access token.
   */
  public interface CredentialsChangedListener {

    /**
     * Notifies that the credentials have changed.
     *
     * <p>This is called when token content changes, such as when the access token is refreshed.
     * This is typically used by code caching the access token.
     *
     * @param credentials The updated credentials instance
     * @throws IOException May be thrown by listeners if saving credentials fails.
     */
    void onChanged(OAuth2Credentials credentials) throws IOException;
  }

  @Override
  public int hashCode() {
    return Objects.hashCode(value);
  }

  /**
   * Returns the request metadata derived from the cached token, or null if no token is cached. Does
   * not trigger a refresh.
   */
  @Nullable
  protected Map<String, List<String>> getRequestMetadataInternal() {
    OAuthValue localValue = value;
    if (localValue != null) {
      return localValue.requestMetadata;
    }
    return null;
  }

  @Override
  public String toString() {
    OAuthValue localValue = value;

    Map<String, List<String>> requestMetadata = null;
    AccessToken temporaryAccess = null;

    if (localValue != null) {
      requestMetadata = localValue.requestMetadata;
      temporaryAccess = localValue.temporaryAccess;
    }
    // NOTE(review): requestMetadata holds the "Authorization: Bearer <token>" header built by
    // OAuthValue.create, so this string contains the token value; avoid logging it.
    return MoreObjects.toStringHelper(this)
        .add("requestMetadata", requestMetadata)
        .add("temporaryAccess", temporaryAccess)
        .toString();
  }

  @Override
  public boolean equals(Object obj) {
    if (!(obj instanceof OAuth2Credentials)) {
      return false;
    }
    OAuth2Credentials other = (OAuth2Credentials) obj;
    return Objects.equals(this.value, other.value);
  }

  /** Restores the transient fields that default deserialization leaves unset. */
  private void readObject(ObjectInputStream input) throws IOException, ClassNotFoundException {
    input.defaultReadObject();
    clock = Clock.SYSTEM;
    refreshTask = null;
  }

  /**
   * Instantiates {@code className} through its no-argument constructor.
   *
   * @param className fully qualified name of the class to instantiate
   * @return the new instance, cast (unchecked) to the caller's expected type
   * @throws ClassNotFoundException if {@code className} cannot be loaded
   * @throws IOException if the class cannot be instantiated reflectively, or if its constructor
   *     throws an IOException or other checked exception
   */
  @SuppressWarnings("unchecked")
  protected static <T> T newInstance(String className) throws IOException, ClassNotFoundException {
    Class<?> clazz = Class.forName(className);
    try {
      // Class.newInstance() is deprecated because it rethrows checked exceptions from the
      // constructor without declaring them; Constructor.newInstance wraps them instead.
      return (T) clazz.getDeclaredConstructor().newInstance();
    } catch (InvocationTargetException e) {
      // Propagate the constructor's own failure rather than the reflective wrapper.
      Throwable cause = e.getCause();
      if (cause instanceof IOException) {
        throw (IOException) cause;
      } else if (cause instanceof RuntimeException) {
        throw (RuntimeException) cause;
      } else if (cause instanceof Error) {
        throw (Error) cause;
      }
      throw new IOException(cause);
    } catch (ReflectiveOperationException e) {
      throw new IOException(e);
    }
  }

  /**
   * Returns the first implementation of {@code clazz} found by {@link ServiceLoader}, or {@code
   * defaultInstance} if none is registered.
   */
  protected static <T> T getFromServiceLoader(Class<? extends T> clazz, T defaultInstance) {
    return Iterables.getFirst(ServiceLoader.load(clazz), defaultInstance);
  }

  /** Returns a new builder with no access token and the default margins. */
  public static Builder newBuilder() {
    return new Builder();
  }

  /** Returns a builder initialized with this instance's access token and margins. */
  public Builder toBuilder() {
    return new Builder(this);
  }

  /** Stores an immutable snapshot of the access token owned by {@link OAuth2Credentials} */
  static class OAuthValue implements Serializable {
    private final AccessToken temporaryAccess;
    private final Map<String, List<String>> requestMetadata;

    /**
     * Builds the snapshot: an {@code Authorization: Bearer <token>} header followed by {@code
     * additionalHeaders}.
     */
    static OAuthValue create(AccessToken token, Map<String, List<String>> additionalHeaders) {
      return new OAuthValue(
          token,
          ImmutableMap.<String, List<String>>builder()
              .put(
                  AuthHttpConstants.AUTHORIZATION,
                  ImmutableList.of(OAuth2Utils.BEARER_PREFIX + token.getTokenValue()))
              .putAll(additionalHeaders)
              .build());
    }

    private OAuthValue(AccessToken temporaryAccess, Map<String, List<String>> requestMetadata) {
      this.temporaryAccess = temporaryAccess;
      this.requestMetadata = requestMetadata;
    }

    @Override
    public boolean equals(Object obj) {
      if (!(obj instanceof OAuthValue)) {
        return false;
      }
      OAuthValue other = (OAuthValue) obj;
      return Objects.equals(this.requestMetadata, other.requestMetadata)
          && Objects.equals(this.temporaryAccess, other.temporaryAccess);
    }

    @Override
    public int hashCode() {
      return Objects.hash(temporaryAccess, requestMetadata);
    }
  }

  /** Freshness of the cached token, as computed by {@code getState()}. */
  enum CacheState {
    /** No refresh needed. */
    FRESH,
    /** Still usable, but within the refresh margin: a refresh should be started. */
    STALE,
    /** Missing or within the expiration margin: callers must wait for a refresh. */
    EXPIRED;
  }

  /** Adapts a {@link FutureCallback} over {@link OAuthValue} to a {@link RequestMetadataCallback}. */
  static class FutureCallbackToMetadataCallbackAdapter implements FutureCallback<OAuthValue> {
    private final RequestMetadataCallback callback;

    public FutureCallbackToMetadataCallbackAdapter(RequestMetadataCallback callback) {
      this.callback = callback;
    }

    @Override
    public void onSuccess(@Nullable OAuthValue value) {
      callback.onSuccess(value.requestMetadata);
    }

    @Override
    public void onFailure(Throwable throwable) {
      // refreshAccessToken will be invoked in an executor, so if it fails unwrap the underlying
      // error
      if (throwable instanceof ExecutionException) {
        throwable = throwable.getCause();
      }
      callback.onFailure(throwable);
    }
  }

  /**
   * Result from {@link com.google.auth.oauth2.OAuth2Credentials#getOrCreateRefreshTask()}.
   *
   * <p>Contains the refresh task and a flag indicating if the task is newly created. If the task is
   * newly created, it is the caller's responsibility to execute it.
   */
  static class AsyncRefreshResult {
    private final RefreshTask task;
    private final boolean isNew;

    AsyncRefreshResult(RefreshTask task, boolean isNew) {
      this.task = task;
      this.isNew = isNew;
    }

    /** Runs the task on {@code executor} only if this caller created it. */
    void executeIfNew(Executor executor) {
      if (isNew) {
        executor.execute(task);
      }
    }
  }

  /** Commits a completed refresh via {@code finishRefreshAsync}. */
  @VisibleForTesting
  class RefreshTaskListener implements Runnable {
    private final ListenableFutureTask<OAuthValue> task;

    RefreshTaskListener(ListenableFutureTask<OAuthValue> task) {
      this.task = task;
    }

    @Override
    public void run() {
      finishRefreshAsync(task);
    }
  }

  /**
   * Runnable future wrapping a refresh task. It completes with the task's result after the
   * credential state has been updated by the {@link RefreshTaskListener}.
   */
  class RefreshTask extends AbstractFuture<OAuthValue> implements Runnable {
    private final ListenableFutureTask<OAuthValue> task;
    private final RefreshTaskListener listener;

    RefreshTask(ListenableFutureTask<OAuthValue> task, RefreshTaskListener listener) {
      this.task = task;
      this.listener = listener;

      // Update Credential state first
      task.addListener(listener, MoreExecutors.directExecutor());

      // Then notify the world
      Futures.addCallback(
          task,
          new FutureCallback<OAuthValue>() {
            @Override
            public void onSuccess(OAuthValue result) {
              RefreshTask.this.set(result);
            }

            @Override
            public void onFailure(Throwable t) {
              RefreshTask.this.setException(t);
            }
          },
          MoreExecutors.directExecutor());
    }

    public ListenableFutureTask<OAuthValue> getTask() {
      return this.task;
    }

    @Override
    public void run() {
      task.run();
    }
  }

  /**
   * Builder for {@link OAuth2Credentials}. Margins default to {@code DEFAULT_REFRESH_MARGIN} and
   * {@code DEFAULT_EXPIRATION_MARGIN}; they are validated by {@link #build()}.
   */
  public static class Builder {

    private AccessToken accessToken;
    private Duration refreshMargin = DEFAULT_REFRESH_MARGIN;
    private Duration expirationMargin = DEFAULT_EXPIRATION_MARGIN;

    protected Builder() {}

    protected Builder(OAuth2Credentials credentials) {
      this.accessToken = credentials.getAccessToken();
      this.refreshMargin = credentials.refreshMargin;
      this.expirationMargin = credentials.expirationMargin;
    }

    @CanIgnoreReturnValue
    public Builder setAccessToken(AccessToken token) {
      this.accessToken = token;
      return this;
    }

    @CanIgnoreReturnValue
    public Builder setRefreshMargin(Duration refreshMargin) {
      this.refreshMargin = refreshMargin;
      return this;
    }

    public Duration getRefreshMargin() {
      return refreshMargin;
    }

    @CanIgnoreReturnValue
    public Builder setExpirationMargin(Duration expirationMargin) {
      this.expirationMargin = expirationMargin;
      return this;
    }

    public Duration getExpirationMargin() {
      return expirationMargin;
    }

    public AccessToken getAccessToken() {
      return accessToken;
    }

    /**
     * Builds the credentials.
     *
     * @throws NullPointerException if a margin is null
     * @throws IllegalArgumentException if a margin is negative
     */
    public OAuth2Credentials build() {
      return new OAuth2Credentials(accessToken, refreshMargin, expirationMargin);
    }
  }
}
