# Copyright 2018 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import json
import os

import mock
import pytest
from six.moves import http_client

from google.auth import _helpers
from google.auth import crypt
from google.auth import exceptions
from google.auth import impersonated_credentials
from google.auth import transport
from google.auth.impersonated_credentials import Credentials
from google.oauth2 import credentials
from google.oauth2 import service_account

# Directory of on-disk test fixtures, resolved relative to this file.
DATA_DIR = os.path.join(os.path.dirname(__file__), "", "data")

# Raw bytes of the PEM private key; used to build SIGNER below.
with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh:
    PRIVATE_KEY_BYTES = fh.read()

SERVICE_ACCOUNT_JSON_FILE = os.path.join(DATA_DIR, "service_account.json")

# Canned ID token served by the mocked ID-token endpoint fixture. Its
# signature segment is the literal placeholder "redacted".
ID_TOKEN_DATA = (
    "eyJhbGciOiJSUzI1NiIsImtpZCI6ImRmMzc1ODkwOGI3OTIyOTNhZDk3N2Ew"
    "Yjk5MWQ5OGE3N2Y0ZWVlY2QiLCJ0eXAiOiJKV1QifQ.eyJhdWQiOiJodHRwc"
    "zovL2Zvby5iYXIiLCJhenAiOiIxMDIxMDE1NTA4MzQyMDA3MDg1NjgiLCJle"
    "HAiOjE1NjQ0NzUwNTEsImlhdCI6MTU2NDQ3MTQ1MSwiaXNzIjoiaHR0cHM6L"
    "y9hY2NvdW50cy5nb29nbGUuY29tIiwic3ViIjoiMTAyMTAxNTUwODM0MjAwN"
    "zA4NTY4In0.redacted"
)
# Epoch seconds the ID-token tests expect as the refreshed expiry;
# presumably the "exp" claim encoded in ID_TOKEN_DATA -- TODO confirm.
ID_TOKEN_EXPIRY = 1564475051

# Parsed service-account JSON. Not referenced by the tests visible in this
# file; kept as a module-level fixture.
with open(SERVICE_ACCOUNT_JSON_FILE, "r") as fh:
    SERVICE_ACCOUNT_INFO = json.load(fh)

# Signer built from the private key above with key id "1".
SIGNER = crypt.RSASigner.from_string(PRIVATE_KEY_BYTES, "1")
# Placeholder token endpoint handed to the source service-account credentials.
TOKEN_URI = "https://example.com/oauth2/token"


@pytest.fixture
def mock_donor_credentials():
    """Patch ``google.oauth2._client.jwt_grant`` for the duration of a test.

    The patched function returns a fixed ``"source token"``, an expiry 500
    seconds after the moment the fixture is set up, and an empty extra-data
    dict. The mock itself is yielded to the test.
    """
    jwt_grant_patch = mock.patch("google.oauth2._client.jwt_grant", autospec=True)
    with jwt_grant_patch as jwt_grant_mock:
        source_expiry = _helpers.utcnow() + datetime.timedelta(seconds=500)
        jwt_grant_mock.return_value = ("source token", source_expiry, {})
        yield jwt_grant_mock


class MockResponse(object):
    """Minimal stand-in for an HTTP response object used by the session mocks.

    Derives from ``object`` explicitly so it is a new-style class on Python 2
    as well, matching the other class declared in this module.

    Args:
        json_data: The value that :meth:`json` will return.
        status_code: The HTTP status code exposed as ``status_code``.
    """

    def __init__(self, json_data, status_code):
        self.json_data = json_data
        self.status_code = status_code

    def json(self):
        """Return the payload supplied at construction time, unchanged."""
        return self.json_data


@pytest.fixture
def mock_authorizedsession_sign():
    """Patch ``AuthorizedSession.request`` to answer with a signing payload.

    Every call to the patched method returns a 200 ``MockResponse`` carrying
    a ``keyId`` and a ``signedBlob`` field. The mock is yielded to the test.
    """
    sign_payload = {"keyId": "1", "signedBlob": "c2lnbmF0dXJl"}
    target = "google.auth.transport.requests.AuthorizedSession.request"
    with mock.patch(target, autospec=True) as auth_session:
        auth_session.return_value = MockResponse(sign_payload, http_client.OK)
        yield auth_session


@pytest.fixture
def mock_authorizedsession_idtoken():
    """Patch ``AuthorizedSession.request`` to answer with an ID-token payload.

    Every call to the patched method returns a 200 ``MockResponse`` whose
    JSON body is ``{"token": ID_TOKEN_DATA}``. The mock is yielded to the test.
    """
    id_token_payload = {"token": ID_TOKEN_DATA}
    target = "google.auth.transport.requests.AuthorizedSession.request"
    with mock.patch(target, autospec=True) as auth_session:
        auth_session.return_value = MockResponse(id_token_payload, http_client.OK)
        yield auth_session


class TestImpersonatedCredentials(object):

    # Identity of the source (donor) service account.
    SERVICE_ACCOUNT_EMAIL = "service-account@example.com"
    # Service account being impersonated; default target for make_credentials.
    TARGET_PRINCIPAL = "impersonated@project.iam.gserviceaccount.com"
    # Scopes passed as target_scopes to the credentials under test.
    TARGET_SCOPES = ["https://www.googleapis.com/auth/devstorage.read_only"]
    # No delegation chain is used by these tests.
    DELEGATES = []
    # Default lifetime argument used by make_credentials.
    LIFETIME = 3600
    # Service-account source credentials built from the module-level signer.
    # NOTE: also used as a default argument value of make_credentials, so it
    # must be defined before that method.
    SOURCE_CREDENTIALS = service_account.Credentials(
        SIGNER, SERVICE_ACCOUNT_EMAIL, TOKEN_URI
    )
    # User (OAuth2) source credentials carrying only a fixed access token.
    USER_SOURCE_CREDENTIALS = credentials.Credentials(token="ABCDE")
    # Alternative generateAccessToken endpoint that the override tests expect
    # the refresh request to be sent to.
    IAM_ENDPOINT_OVERRIDE = (
        "https://us-east1-iamcredentials.googleapis.com/v1/projects/-"
        + "/serviceAccounts/{}:generateAccessToken".format(SERVICE_ACCOUNT_EMAIL)
    )

    def make_credentials(
        self,
        source_credentials=SOURCE_CREDENTIALS,
        lifetime=LIFETIME,
        target_principal=TARGET_PRINCIPAL,
        iam_endpoint_override=None,
    ):
        """Build the impersonated ``Credentials`` object exercised by a test.

        Args:
            source_credentials: Credentials used as the impersonation source.
            lifetime: Value forwarded as the ``lifetime`` argument.
            target_principal: Service account to impersonate.
            iam_endpoint_override: Optional endpoint forwarded unchanged.

        Returns:
            A new ``Credentials`` instance using the class-level
            ``TARGET_SCOPES`` and ``DELEGATES``.
        """
        constructor_kwargs = dict(
            source_credentials=source_credentials,
            target_principal=target_principal,
            target_scopes=self.TARGET_SCOPES,
            delegates=self.DELEGATES,
            lifetime=lifetime,
            iam_endpoint_override=iam_endpoint_override,
        )
        return Credentials(**constructor_kwargs)

    def test_make_from_user_credentials(self):
        """Credentials built from user source credentials start out expired."""
        impersonated = self.make_credentials(
            source_credentials=self.USER_SOURCE_CREDENTIALS
        )

        assert not impersonated.valid
        assert impersonated.expired

    def test_default_state(self):
        """Freshly built credentials are neither valid nor unexpired."""
        impersonated = self.make_credentials()

        assert not impersonated.valid
        assert impersonated.expired

    def make_request(
        self,
        data,
        status=http_client.OK,
        headers=None,
        side_effect=None,
        use_data_bytes=True,
    ):
        """Create an autospec'd transport request that yields a canned response.

        Args:
            data: Response payload.
            status: Value assigned to the response's ``status`` attribute.
            headers: Response headers; an empty dict is used when falsy.
            side_effect: Assigned to the request mock's ``side_effect``.
            use_data_bytes: When true the payload is converted with
                ``_helpers.to_bytes``; otherwise it is stored as given.

        Returns:
            The request mock; calling it returns the prepared response mock.
        """
        if use_data_bytes:
            payload = _helpers.to_bytes(data)
        else:
            payload = data

        canned_response = mock.create_autospec(transport.Response, instance=False)
        canned_response.status = status
        canned_response.data = payload
        canned_response.headers = headers or {}

        request_mock = mock.create_autospec(transport.Request, instance=False)
        request_mock.side_effect = side_effect
        request_mock.return_value = canned_response
        return request_mock

    @pytest.mark.parametrize("use_data_bytes", [True, False])
    def test_refresh_success(self, use_data_bytes, mock_donor_credentials):
        """A 200 token response makes the credentials valid and unexpired.

        Parametrized so the response payload is supplied both as bytes and
        in its original string form.
        """
        impersonated = self.make_credentials(lifetime=None)

        # Expiry 500 seconds ahead, whole seconds only, with a "Z" suffix.
        expires_at = _helpers.utcnow().replace(microsecond=0)
        expires_at = expires_at + datetime.timedelta(seconds=500)
        payload = json.dumps(
            {"accessToken": "token", "expireTime": expires_at.isoformat("T") + "Z"}
        )

        request = self.make_request(
            data=payload, status=http_client.OK, use_data_bytes=use_data_bytes
        )
        impersonated.refresh(request)

        assert impersonated.valid
        assert not impersonated.expired

    @pytest.mark.parametrize("use_data_bytes", [True, False])
    def test_refresh_success_iam_endpoint_override(
        self, use_data_bytes, mock_donor_credentials
    ):
        """Refresh succeeds and is sent to the overridden IAM endpoint."""
        impersonated = self.make_credentials(
            lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE
        )

        # Expiry 500 seconds ahead, whole seconds only, with a "Z" suffix.
        expires_at = _helpers.utcnow().replace(microsecond=0)
        expires_at = expires_at + datetime.timedelta(seconds=500)
        payload = json.dumps(
            {"accessToken": "token", "expireTime": expires_at.isoformat("T") + "Z"}
        )

        request = self.make_request(
            data=payload, status=http_client.OK, use_data_bytes=use_data_bytes
        )
        impersonated.refresh(request)

        assert impersonated.valid
        assert not impersonated.expired

        # The keyword arguments of the last request must carry the override URL.
        sent_kwargs = request.call_args[1]
        assert sent_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE

    @pytest.mark.parametrize("time_skew", [100, -100])
    def test_refresh_source_credentials(self, time_skew):
        """Source credentials are refreshed only when close enough to expiry.

        The source expiry is placed ``time_skew`` seconds past
        ``_helpers.REFRESH_THRESHOLD`` from now, so a refresh of the source is
        expected exactly when ``time_skew`` is not positive.
        """
        impersonated = self.make_credentials(lifetime=None)
        source = impersonated._source_credentials

        skew = datetime.timedelta(seconds=time_skew)
        source.expiry = _helpers.utcnow() + _helpers.REFRESH_THRESHOLD + skew
        source.token = "Token"

        refresh_target = "google.oauth2.service_account.Credentials.refresh"
        with mock.patch(refresh_target, autospec=True) as source_cred_refresh:
            # Token response expiring 500 seconds ahead (whole seconds, "Z").
            expires_at = _helpers.utcnow().replace(microsecond=0)
            expires_at = expires_at + datetime.timedelta(seconds=500)
            payload = json.dumps(
                {
                    "accessToken": "token",
                    "expireTime": expires_at.isoformat("T") + "Z",
                }
            )
            request = self.make_request(data=payload, status=http_client.OK)

            impersonated.refresh(request)

            assert impersonated.valid
            assert not impersonated.expired

            # Inside the threshold window the source must be refreshed once;
            # outside of it the source must be left alone.
            if time_skew <= 0:
                source_cred_refresh.assert_called_once()
            else:
                source_cred_refresh.assert_not_called()

    def test_refresh_failure_malformed_expire_time(self, mock_donor_credentials):
        """An ``expireTime`` without the trailing "Z" fails the refresh."""
        impersonated = self.make_credentials(lifetime=None)

        # Deliberately malformed: microseconds are kept and no "Z" is appended.
        bad_expire_time = (
            _helpers.utcnow() + datetime.timedelta(seconds=500)
        ).isoformat("T")
        payload = json.dumps({"accessToken": "token", "expireTime": bad_expire_time})

        request = self.make_request(data=payload, status=http_client.OK)

        with pytest.raises(exceptions.RefreshError) as excinfo:
            impersonated.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

        assert not impersonated.valid
        assert impersonated.expired

    def test_refresh_failure_unauthorzed(self, mock_donor_credentials):
        """An UNAUTHORIZED response with an error body fails the refresh."""
        impersonated = self.make_credentials(lifetime=None)

        error_details = {
            "code": 403,
            "message": "The caller does not have permission",
            "status": "PERMISSION_DENIED",
        }
        payload = json.dumps({"error": error_details})

        request = self.make_request(data=payload, status=http_client.UNAUTHORIZED)

        with pytest.raises(exceptions.RefreshError) as excinfo:
            impersonated.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

        assert not impersonated.valid
        assert impersonated.expired

    def test_refresh_failure_http_error(self, mock_donor_credentials):
        """A non-OK HTTP response with an empty JSON body fails the refresh.

        The response status is a real HTTP status code (500). The previous
        version passed the ``http_client.HTTPException`` *class* as the
        status, which is an exception type rather than a status code and only
        worked because it happens to compare unequal to ``http_client.OK``.
        """
        credentials = self.make_credentials(lifetime=None)

        response_body = {}

        request = self.make_request(
            data=json.dumps(response_body),
            status=http_client.INTERNAL_SERVER_ERROR,
        )

        with pytest.raises(exceptions.RefreshError) as excinfo:
            credentials.refresh(request)

        assert excinfo.match(impersonated_credentials._REFRESH_ERROR)

        # A failed refresh must leave the credentials unusable.
        assert not credentials.valid
        assert credentials.expired

    def test_expired(self):
        """Credentials that were never refreshed report themselves expired."""
        impersonated = self.make_credentials(lifetime=None)

        assert impersonated.expired

    def test_signer(self):
        """The ``signer`` property yields an impersonated ``Credentials``."""
        impersonated = self.make_credentials()
        signer = impersonated.signer

        assert isinstance(signer, impersonated_credentials.Credentials)

    def test_signer_email(self):
        """``signer_email`` equals the target principal given at creation."""
        principal = self.TARGET_PRINCIPAL
        impersonated = self.make_credentials(target_principal=principal)

        assert impersonated.signer_email == self.TARGET_PRINCIPAL

    def test_service_account_email(self):
        """``service_account_email`` equals the target principal."""
        principal = self.TARGET_PRINCIPAL
        impersonated = self.make_credentials(target_principal=principal)

        assert impersonated.service_account_email == self.TARGET_PRINCIPAL

    def test_sign_bytes(self, mock_donor_credentials, mock_authorizedsession_sign):
        """After a successful refresh, ``sign_bytes`` returns the decoded blob.

        The token request is built with the shared ``make_request`` helper,
        like the other tests in this class, instead of hand-assembling the
        autospec'd request/response pair. The signing call itself is served
        by the ``mock_authorizedsession_sign`` fixture.
        """
        credentials = self.make_credentials(lifetime=None)

        # Token response expiring 500 seconds ahead (whole seconds, "Z").
        expire_time = (
            _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500)
        ).isoformat("T") + "Z"
        token_response_body = {"accessToken": "token", "expireTime": expire_time}

        request = self.make_request(
            data=json.dumps(token_response_body), status=http_client.OK
        )

        credentials.refresh(request)

        assert credentials.valid
        assert not credentials.expired

        signature = credentials.sign_bytes(b"signed bytes")
        assert signature == b"signature"

    def test_sign_bytes_failure(self):
        """A FORBIDDEN signing response surfaces as a ``TransportError``."""
        impersonated = self.make_credentials(lifetime=None)
        error_payload = {"error": {"code": 403, "message": "unauthorized"}}

        session_patch = mock.patch(
            "google.auth.transport.requests.AuthorizedSession.request", autospec=True
        )
        with session_patch as auth_session:
            auth_session.return_value = MockResponse(
                error_payload, http_client.FORBIDDEN
            )

            with pytest.raises(exceptions.TransportError) as excinfo:
                impersonated.sign_bytes(b"foo")
            # The raised error is expected to mention the returned error code.
            assert excinfo.match("'code': 403")

    def test_with_quota_project(self):
        """``with_quota_project`` returns credentials with that project set."""
        base_creds = self.make_credentials()

        derived = base_creds.with_quota_project("project-foo")

        assert derived._quota_project_id == "project-foo"

    @pytest.mark.parametrize("use_data_bytes", [True, False])
    def test_with_quota_project_iam_endpoint_override(
        self, use_data_bytes, mock_donor_credentials
    ):
        """Credentials derived via ``with_quota_project`` keep the override.

        The derived credentials are refreshed and the request is checked to
        have been sent to ``IAM_ENDPOINT_OVERRIDE``.
        """
        base_creds = self.make_credentials(
            lifetime=None, iam_endpoint_override=self.IAM_ENDPOINT_OVERRIDE
        )
        derived = base_creds.with_quota_project("project-foo")

        # Token response expiring 500 seconds ahead (whole seconds, "Z").
        expires_at = _helpers.utcnow().replace(microsecond=0)
        expires_at = expires_at + datetime.timedelta(seconds=500)
        payload = json.dumps(
            {"accessToken": "token", "expireTime": expires_at.isoformat("T") + "Z"}
        )

        request = self.make_request(
            data=payload, status=http_client.OK, use_data_bytes=use_data_bytes
        )
        derived.refresh(request)

        assert derived.valid
        assert not derived.expired

        # The keyword arguments of the last request must carry the override URL.
        sent_kwargs = request.call_args[1]
        assert sent_kwargs["url"] == self.IAM_ENDPOINT_OVERRIDE

    def test_id_token_success(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        """Refreshing ``IDTokenCredentials`` stores the canned token and expiry."""
        audience = "https://foo.bar"
        impersonated = self.make_credentials(lifetime=None)

        # Access-token response expiring 500 seconds ahead (whole seconds, "Z").
        expires_at = _helpers.utcnow().replace(microsecond=0)
        expires_at = expires_at + datetime.timedelta(seconds=500)
        payload = json.dumps(
            {"accessToken": "token", "expireTime": expires_at.isoformat("T") + "Z"}
        )
        request = self.make_request(data=payload, status=http_client.OK)

        impersonated.refresh(request)

        assert impersonated.valid
        assert not impersonated.expired

        id_creds = impersonated_credentials.IDTokenCredentials(
            impersonated, target_audience=audience
        )
        id_creds.refresh(request)

        # Expiry is compared against a naive datetime built with fromtimestamp.
        expected_expiry = datetime.datetime.fromtimestamp(ID_TOKEN_EXPIRY)
        assert id_creds.token == ID_TOKEN_DATA
        assert id_creds.expiry == expected_expiry

    def test_id_token_from_credential(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        """``from_credentials`` keeps ``include_email`` and still refreshes."""
        audience = "https://foo.bar"
        impersonated = self.make_credentials(lifetime=None)

        # Access-token response expiring 500 seconds ahead (whole seconds, "Z").
        expires_at = _helpers.utcnow().replace(microsecond=0)
        expires_at = expires_at + datetime.timedelta(seconds=500)
        payload = json.dumps(
            {"accessToken": "token", "expireTime": expires_at.isoformat("T") + "Z"}
        )
        request = self.make_request(data=payload, status=http_client.OK)

        impersonated.refresh(request)

        assert impersonated.valid
        assert not impersonated.expired

        original_id_creds = impersonated_credentials.IDTokenCredentials(
            impersonated, target_audience=audience, include_email=True
        )
        id_creds = original_id_creds.from_credentials(target_credentials=impersonated)
        id_creds.refresh(request)

        assert id_creds.token == ID_TOKEN_DATA
        assert id_creds._include_email is True

    def test_id_token_with_target_audience(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        """``with_target_audience`` keeps ``include_email`` and refreshes."""
        audience = "https://foo.bar"
        impersonated = self.make_credentials(lifetime=None)

        # Access-token response expiring 500 seconds ahead (whole seconds, "Z").
        expires_at = _helpers.utcnow().replace(microsecond=0)
        expires_at = expires_at + datetime.timedelta(seconds=500)
        payload = json.dumps(
            {"accessToken": "token", "expireTime": expires_at.isoformat("T") + "Z"}
        )
        request = self.make_request(data=payload, status=http_client.OK)

        impersonated.refresh(request)

        assert impersonated.valid
        assert not impersonated.expired

        without_audience = impersonated_credentials.IDTokenCredentials(
            impersonated, include_email=True
        )
        id_creds = without_audience.with_target_audience(target_audience=audience)
        id_creds.refresh(request)

        # Expiry is compared against a naive datetime built with fromtimestamp.
        expected_expiry = datetime.datetime.fromtimestamp(ID_TOKEN_EXPIRY)
        assert id_creds.token == ID_TOKEN_DATA
        assert id_creds.expiry == expected_expiry
        assert id_creds._include_email is True

    def test_id_token_invalid_cred(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        """Passing ``None`` as the target credentials is rejected."""
        expected_message = "Provided Credential must be impersonated_credentials"

        with pytest.raises(exceptions.GoogleAuthError) as excinfo:
            impersonated_credentials.IDTokenCredentials(None)

        assert excinfo.match(expected_message)

    def test_id_token_with_include_email(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        """Credentials from ``with_include_email(True)`` refresh normally."""
        audience = "https://foo.bar"
        impersonated = self.make_credentials(lifetime=None)

        # Access-token response expiring 500 seconds ahead (whole seconds, "Z").
        expires_at = _helpers.utcnow().replace(microsecond=0)
        expires_at = expires_at + datetime.timedelta(seconds=500)
        payload = json.dumps(
            {"accessToken": "token", "expireTime": expires_at.isoformat("T") + "Z"}
        )
        request = self.make_request(data=payload, status=http_client.OK)

        impersonated.refresh(request)

        assert impersonated.valid
        assert not impersonated.expired

        without_email = impersonated_credentials.IDTokenCredentials(
            impersonated, target_audience=audience
        )
        id_creds = without_email.with_include_email(True)
        id_creds.refresh(request)

        assert id_creds.token == ID_TOKEN_DATA

    def test_id_token_with_quota_project(
        self, mock_donor_credentials, mock_authorizedsession_idtoken
    ):
        """``with_quota_project`` on ID-token credentials sets the project id."""
        audience = "https://foo.bar"
        impersonated = self.make_credentials(lifetime=None)

        # Access-token response expiring 500 seconds ahead (whole seconds, "Z").
        expires_at = _helpers.utcnow().replace(microsecond=0)
        expires_at = expires_at + datetime.timedelta(seconds=500)
        payload = json.dumps(
            {"accessToken": "token", "expireTime": expires_at.isoformat("T") + "Z"}
        )
        request = self.make_request(data=payload, status=http_client.OK)

        impersonated.refresh(request)

        assert impersonated.valid
        assert not impersonated.expired

        without_quota = impersonated_credentials.IDTokenCredentials(
            impersonated, target_audience=audience
        )
        id_creds = without_quota.with_quota_project("project-foo")
        id_creds.refresh(request)

        assert id_creds.quota_project_id == "project-foo"
