# Copyright (C) 2020 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for atest."""
from __future__ import print_function

import getpass
import logging
import os
import pathlib
from pathlib import Path
from socket import socket
import subprocess
import time
from typing import Any, Callable
import uuid

from atest import atest_utils
from atest import constants
from atest.atest_enum import DetectType
from atest.metrics import metrics
import httplib2
from oauth2client import client as oauth2_client
from oauth2client import contrib as oauth2_contrib
from oauth2client import tools as oauth2_tools


class RunFlowFlags:
  """Plain attribute bag passed as the `flags` argument of run_flow.

  oauth2client.tools.run_flow reads its options from this object instead of
  from parsed command-line arguments.

  Args:
      browser_auth: Truthy to allow the local-webserver (browser) auth path.
  """

  def __init__(self, browser_auth):
    # Browser auth on => local webserver allowed, and vice versa.
    self.noauth_local_webserver = not browser_auth
    self.logging_level = 'ERROR'
    # Host/ports tried, in order, for the local redirect server.
    self.auth_host_name = 'localhost'
    self.auth_host_port = [8080, 8090]


class GCPHelper:
  """GCP bucket helper class.

  Obtains OAuth2 credentials for the configured client. It tries a corp SSO
  token exchange first, then a cached credential file, and finally an
  interactive browser-based auth flow.
  """

  def __init__(
      self,
      client_id=None,
      client_secret=None,
      user_agent=None,
      scope=constants.SCOPE_BUILD_API_SCOPE,
  ):
    """Init stuff for GCPHelper class.

    Args:
        client_id: String, client id from the cloud project.
        client_secret: String, client secret for the client_id.
        user_agent: The user agent for the credential.
        scope: String, scopes separated by space.
    """
    self.client_id = client_id
    self.client_secret = client_secret
    self.user_agent = user_agent
    self.scope = scope

  def _get_credential_storage(self, creds_file_path):
    """Return the multiprocess file storage for creds_file_path.

    Args:
        creds_file_path: Credential file path.

    Returns:
        The storage object built by
        oauth2client.contrib.multiprocess_file_storage.get_credential_storage
        for this helper's client id, user agent and scope.
    """
    return oauth2_contrib.multiprocess_file_storage.get_credential_storage(
        filename=os.path.abspath(creds_file_path),
        client_id=self.client_id,
        user_agent=self.user_agent,
        scope=self.scope,
    )

  def get_refreshed_credential_from_file(self, creds_file_path):
    """Get refreshed credential from file.

    Args:
        creds_file_path: Credential file path.

    Returns:
        An oauth2client.OAuth2Credentials instance, or None if nothing usable
        is stored or the credential is still invalid after the refresh attempt.
    """
    credential = self.get_credential_from_file(creds_file_path)
    if credential:
      try:
        credential.refresh(httplib2.Http())
      except oauth2_client.AccessTokenRefreshError as e:
        # A failed refresh is tolerated; the `invalid` check below decides.
        logging.debug('Token refresh error: %s', e)
      if not credential.invalid:
        return credential
    logging.debug('Cannot get credential.')
    return None

  def get_credential_from_file(self, creds_file_path):
    """Get credential from file.

    Args:
        creds_file_path: Credential file path.

    Returns:
        An oauth2client.OAuth2Credentials instance (may be None when nothing
        has been stored yet).
    """
    return self._get_credential_storage(creds_file_path).get()

  def get_credential_with_auth_flow(self, creds_file_path):
    """Get Credential object, running the oauth flow if never authorized.

    The order is: SSO token exchange, then the refreshed cached credential
    from creds_file_path, then an interactive auth flow whose result is saved
    into the same file storage.

    Args:
        creds_file_path: Credential file path.

    Returns:
        An oauth2client.OAuth2Credentials instance.
    """
    # SSO auth
    try:
      token = self._get_sso_access_token()
      # Only use SSO when it actually produced a token. Otherwise we would hand
      # back AccessTokenCredentials wrapping None (which is still truthy)
      # instead of falling back to the GCP auth flow.
      if token:
        return oauth2_client.AccessTokenCredentials(token, 'atest')
    # pylint: disable=broad-except
    except Exception as e:
      # SSO is best-effort; any failure falls through to the GCP auth flow.
      logging.debug('Exception:%s', e)
    # GCP auth flow
    credentials = self.get_refreshed_credential_from_file(creds_file_path)
    if credentials:
      return credentials
    return self._run_auth_flow(self._get_credential_storage(creds_file_path))

  def _run_auth_flow(self, storage):
    """Get user oauth2 credentials.

    Using the loopback IP address flow for desktop clients.

    Args:
        storage: GCP storage object.

    Returns:
        An oauth2client.OAuth2Credentials instance.
    """
    flags = RunFlowFlags(browser_auth=True)

    # Get a free port on demand (retry until it is >= 10000).
    port = None
    while not port or port < 10000:
      with socket() as local_socket:
        local_socket.bind(('', 0))
        _, port = local_socket.getsockname()
    # run_flow starts its local redirect server on the first bindable port in
    # flags.auth_host_port and rewrites flow.redirect_uri to match. The chosen
    # port therefore has to go through the flags to take effect. The previous
    # fixed ports stay as fallbacks in case the port got taken meanwhile.
    flags.auth_host_port = [port, *flags.auth_host_port]
    flow = oauth2_client.OAuth2WebServerFlow(
        client_id=self.client_id,
        client_secret=self.client_secret,
        scope=self.scope,
        user_agent=self.user_agent,
        redirect_uri=f'http://localhost:{port}',
    )
    return oauth2_tools.run_flow(flow=flow, storage=storage, flags=flags)

  @staticmethod
  def _get_sso_access_token():
    """Exchange the corp SSO credential for a scoped oauth token via stubby.

    Returns:
        The token string, or None when constants.TOKEN_EXCHANGE_COMMAND is
        not configured.

    Raises:
        subprocess.CalledProcessError: If the exchange command exits non-zero.
        IndexError: If the command output contains no double-quoted token.
    """
    if not constants.TOKEN_EXCHANGE_COMMAND:
      return None

    request = constants.TOKEN_EXCHANGE_REQUEST.format(
        user=getpass.getuser(), scope=constants.SCOPE
    )
    # The output format is: oauth2_token: "<TOKEN>"
    return subprocess.run(
        constants.TOKEN_EXCHANGE_COMMAND,
        input=request,
        check=True,
        text=True,
        shell=True,
        stdout=subprocess.PIPE,
    ).stdout.split('"')[1]


# TODO: The usage of build_client should be removed from this method because
# it's not related to this module. For now, build_client_creator is typed as a
# bare Callable (its BuildClient return type is left unspecified) to avoid
# circular importing.
def do_upload_flow(
    extra_args: dict[str, str],
    build_client_creator: Callable,
    invocation_properties: dict[str, str] = None,
) -> tuple:
  """Run upload flow.

  Fetches a credential, creates the build/invocation/work unit records
  through the build client, and records their identifiers in extra_args. It
  then writes the access token to constants.TOKEN_FILE_PATH.

  Args:
      extra_args: Dict of extra args to add to test run. Updated in place with
        the invocation id, work unit id, local build id and build target.
      build_client_creator: A function that takes a credential and returns a
        BuildClient object.
      invocation_properties: Additional invocation properties to write into the
        invocation.

  Return:
      A tuple of credential object and invocation information dict, or
      (None, None) when no credential could be fetched.
  """
  invocation_properties = invocation_properties or {}
  fetch_cred_start = time.time()
  creds = fetch_credential()
  metrics.LocalDetectEvent(
      detect_type=DetectType.FETCH_CRED_MS,
      result=int((time.time() - fetch_cred_start) * 1000),
  )
  if not creds:
    return None, None

  prepare_upload_start = time.time()
  build_client = build_client_creator(creds)
  inv, workunit, local_build_id, build_target = _prepare_data(
      build_client, invocation_properties
  )
  metrics.LocalDetectEvent(
      detect_type=DetectType.UPLOAD_PREPARE_MS,
      result=int((time.time() - prepare_upload_start) * 1000),
  )
  extra_args[constants.INVOCATION_ID] = inv['invocationId']
  extra_args[constants.WORKUNIT_ID] = workunit['id']
  extra_args[constants.LOCAL_BUILD_ID] = local_build_id
  extra_args[constants.BUILD_TARGET] = build_target
  _write_token_file(creds)
  return creds, inv


def _write_token_file(creds) -> None:
  """Write the credential's access token to constants.TOKEN_FILE_PATH.

  The token comes from creds.token_response['access_token'] when a token
  response is present, otherwise from creds.access_token. The file is
  restricted to the owner (0600) because it holds a bearer token.
  """
  token_path = constants.TOKEN_FILE_PATH
  token_dir = os.path.dirname(token_path)
  if token_dir:
    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(token_dir, exist_ok=True)
  if creds.token_response:
    token = creds.token_response['access_token']
  else:
    token = creds.access_token
  fd = os.open(token_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
  with os.fdopen(fd, 'w') as token_file:
    # The O_CREAT mode only applies to newly created files; also tighten a
    # file left behind with broader permissions.
    os.chmod(token_path, 0o600)
    token_file.write(token)


def fetch_credential():
  """Return the OAuth2 credential used for result uploading.

  The credential cache file is constants.CREDENTIAL_FILE_NAME inside atest's
  config folder. It is resolved via GCPHelper.get_credential_with_auth_flow
  using the atest client id/secret.
  """
  config_dir = atest_utils.get_config_folder()
  credential_file = config_dir.joinpath(constants.CREDENTIAL_FILE_NAME)
  helper = GCPHelper(
      client_id=constants.CLIENT_ID,
      client_secret=constants.CLIENT_SECRET,
      user_agent='atest',
  )
  return helper.get_credential_with_auth_flow(credential_file)


def _prepare_data(client, invocation_properties: dict[str, str]):
  """Create the build, invocation and work unit records for this run.

  Args:
      client: The logstorage_utils.BuildClient object.
      invocation_properties: Additional invocation properties to write into the
        invocation.

  Return:
      A tuple (invocation, workunit, build_id, build_target): the inserted
      invocation and work unit objects, plus the build id and build target of
      the inserted local build.
  """
  # Suppress INFO-and-below log records while talking to the build API, then
  # restore the previous global disable level instead of unconditionally
  # re-enabling everything (which would undo a caller's own logging.disable).
  previous_disable_level = logging.root.manager.disable
  logging.disable(logging.INFO)
  try:
    external_id = str(uuid.uuid4())
    branch = _get_branch(client)
    target = _get_target(branch, client)
    build_record = client.insert_local_build(external_id, target, branch)
    client.insert_build_attempts(build_record)
    invocation = client.insert_invocation(build_record, invocation_properties)
    workunit = client.insert_work_unit(invocation)
    return invocation, workunit, build_record['buildId'], target
  finally:
    logging.disable(previous_disable_level)


def _get_branch(build_client):
  """Choose the build API branch to associate with the local build.

  Args:
      build_client: The build client object.

  Return:
      "git_<manifest branch>" when build_client.get_branch() finds that branch.
      Otherwise "git_main" if constants.CREDENTIAL_FILE_NAME is set, and
      "aosp-main" if it is not.
  """
  candidate = 'git_%s' % atest_utils.get_manifest_branch()
  if build_client.get_branch(candidate):
    return candidate
  # Fall back to a default branch keyed on whether credentials are configured.
  if constants.CREDENTIAL_FILE_NAME:
    return 'git_main'
  return 'aosp-main'


def _get_target(branch, build_client):
  """Choose the build target to associate with the local build.

  Args:
      branch: The branch whose targets are listed.
      build_client: The build client object.

  Return:
      The local build target when the branch's target list contains it,
      otherwise "aosp_x86_64-trunk_staging-userdebug".
  """
  wanted = atest_utils.get_build_target()
  response = build_client.list_target(branch)
  known_targets = [entry['target'] for entry in response['targets']]
  if wanted in known_targets:
    return wanted
  return 'aosp_x86_64-trunk_staging-userdebug'
