# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for creating proto service and HTTP action handlers.

The `@proto_action` function annotates a method as implementing a proto service
method. The annotated method should have the type
`Callable[[RequestMessage], ResponseMessage]`. The decorator will take care of
transcoding to/from a HTTP request, similar to
https://cloud.google.com/endpoints/docs/grpc/transcoding. The transcoding only
supports proto-over-http ('?alt=proto').

The `@http_action` function annotates a method as implementing a HTTP action at
some request path. The annotated method will receive the request body, and
should return a `HttpResponse`.

The `create_handler` function merges one or more objects with decorated methods
into a single request handler that's compatible with `http.server`.
"""

import collections
import dataclasses
import enum
import gzip
import http
import http.server
import re
from typing import Any, Callable, Mapping, Match, Pattern, Type, TypeVar
import urllib.parse
import zlib

from absl import logging

from google.api import annotations_pb2
from google.protobuf import descriptor_pool
from google.protobuf import message
from google.protobuf import message_factory

_CallableT = TypeVar('_CallableT', bound=Callable)

_HTTP_ACTION_ATTR = '_http_action_data'
_FACTORY = message_factory.MessageFactory(descriptor_pool.Default())


@dataclasses.dataclass(frozen=True)
class HttpError(Exception):
  """An Exception specifying the HTTP error to return."""
  code: http.HTTPStatus


@dataclasses.dataclass(frozen=True)
class HttpResponse:
  """Information for a successful HTTP response."""
  body: bytes
  headers: Mapping[str, str] = dataclasses.field(default_factory=lambda: {})


def proto_action(*,
                 service=str,
                 method=str) -> Callable[[_CallableT], _CallableT]:
  """Decorator annotating a method as handling a proto service method.

  The `google.api.http` annotation on the method will determine what requests
  will be handled by the decorated function. Only a subset of methods and path
  patterns are currently supported.

  The decorated method will be called with the request message; it should return
  a response message or or throw an `HttpError`.

  Args:
    service: The full name of the proto service.
    method: The name of the method.

  Returns:
    An annotated function.
  """
  try:
    desc = _FACTORY.pool.FindServiceByName(service).FindMethodByName(method)
  except KeyError as e:
    raise ValueError(f'Unable to find /{service}.{method}.') from e

  rule = desc.GetOptions().Extensions[annotations_pb2.http]
  pattern_kind = rule.WhichOneof('pattern')
  try:
    http_method = _HttpMethod[pattern_kind.upper()]
  except KeyError as e:
    raise ValueError(
        f'The google.api.http annotation on /{service}.{method} is invalid '
        'or unsupported.') from e
  path = _convert_pattern(getattr(rule, pattern_kind), alt_proto=True)

  def handler(match: Match[str], body: bytes,
              fn: Callable[[message.Message], message.Message]) -> HttpResponse:
    request = _FACTORY.GetPrototype(desc.input_type)()
    if rule.body == '*':
      try:
        request.ParseFromString(body)
      except message.DecodeError as e:
        raise HttpError(code=http.HTTPStatus.BAD_REQUEST) from e
    elif rule.body:
      setattr(request, rule.body, body)
    # Set any fields from the request path.
    for prop, value in match.groupdict().items():
      try:
        unescaped = urllib.parse.unquote(value)
      except UnicodeError as e:
        raise HttpError(code=http.HTTPStatus.BAD_REQUEST) from e
      setattr(request, prop, unescaped)

    response_body = fn(request).SerializeToString()
    return HttpResponse(
        body=response_body,
        headers={
            'Content-Length': len(response_body),
            'Content-Type': 'application/x-protobuf'
        })

  def annotate_method(func: _CallableT) -> _CallableT:
    setattr(func, _HTTP_ACTION_ATTR,
            _HttpActionData(method=http_method, path=path, handler=handler))
    return func

  return annotate_method


def http_action(*, method: str,
                pattern: str) -> Callable[[_CallableT], _CallableT]:
  """Decorator annotating a method as an HTTP action handler.

  Request matching the method and pattern will be handled by the decorated
  method. The pattern may contain bracket-enclosed keywords (e.g.,
  '/data/{path}'), which will be matched against the request and passed
  to the decorated function as keyword arguments.

  The decorated method will be called with the request body (if any) and any
  keyword args from the path pattern; it should return a `HttpResponse` or throw
  an `HttpError`.

  Args:
    method: The type of HTTP method ('GET' or 'POST').
    pattern: The url pattern to match.

  Returns:
    An annotated function.
  """
  try:
    http_method = _HttpMethod[method.upper()]
  except KeyError as e:
    raise ValueError(f'unsupported HTTP method `{method}`') from e
  path = _convert_pattern(pattern)

  def handler(match: Match[str], body: bytes,
              fn: Callable[[bytes], HttpResponse]) -> HttpResponse:
    try:
      args = {k: urllib.parse.unquote(v) for k, v in match.groupdict().items()}
    except UnicodeError as e:
      raise HttpError(code=http.HTTPStatus.BAD_REQUEST) from e
    return fn(body, **args)

  def annotate_method(func: _CallableT) -> _CallableT:
    setattr(func, _HTTP_ACTION_ATTR,
            _HttpActionData(method=http_method, path=path, handler=handler))
    return func

  return annotate_method


def create_handler(*services: Any) -> Type[http.server.BaseHTTPRequestHandler]:
  """Builds a BaseHTTPRequestHandler that delegates to decorated methods.

  The returned BaseHTTPRequestHandler class will route requests to decorated
  methods of the provided services, or return 404 if the request path does not
  match any action handlers. If the request path matches multiple registered
  action handlers, it's unspecified which will be invoked.

  Args:
    *services: A list of objects with methods decorated with `@proto_action` or
      `@http_action`.

  Returns:
    A BaseHTTPRequestHandler subclass.
  """

  # Collect all handlers, keyed by HTTP method.
  handlers = collections.defaultdict(lambda: [])
  for service in services:
    for attr_name in dir(service):
      attr = getattr(service, attr_name)
      if not callable(attr):
        continue
      data = getattr(attr, _HTTP_ACTION_ATTR, None)
      if isinstance(data, _HttpActionData):
        handlers[data.method].append((data, attr))

  format_handlers = lambda h: ''.join([f'\n  * {e[0].path.pattern}' for e in h])
  logging.debug(
      'Creating HTTP request handler for path patterns:\nGET:%s\nPOST:%s',
      format_handlers(handlers[_HttpMethod.GET]),
      format_handlers(handlers[_HttpMethod.POST]))

  class RequestHandler(http.server.BaseHTTPRequestHandler):
    """Handler that delegates to `handlers`."""

    def do_GET(self) -> None:  # pylint:disable=invalid-name (override)
      self._handle_request(_HttpMethod.GET, read_body=False)

    def do_POST(self) -> None:  # pylint:disable=invalid-name (override)
      self._handle_request(_HttpMethod.POST)

    def _handle_request(self,
                        method: _HttpMethod,
                        read_body: bool = True) -> None:
      """Reads and delegates an incoming request to a registered handler."""
      for data, fn in handlers[method]:
        match = data.path.fullmatch(self.path)
        if match is None:
          continue

        try:
          body = self._read_body() if read_body else b''
          response = data.handler(match, body, fn)
        except HttpError as e:
          logging.debug('%s error: %s', self.path, e)
          return self.send_error(e.code)
        return self._send_response(response)

      # If no handler matched the path, return an error.
      self.send_error(http.HTTPStatus.NOT_FOUND)

    def _read_body(self) -> bytes:
      """Reads the body of the request."""
      body = self.rfile.read(int(self.headers['Content-Length']))
      if self.headers['Content-Encoding'] == 'gzip':
        try:
          body = gzip.decompress(body)
        except (gzip.BadGzipFile, zlib.error) as e:
          raise HttpError(http.HTTPStatus.BAD_REQUEST) from e
      elif self.headers['Content-Encoding']:
        logging.warning('Unsupported content encoding %s',
                        self.headers['Content-Encoding'])
        raise HttpError(http.HTTPStatus.BAD_REQUEST)
      return body

    def _send_response(self, response: HttpResponse) -> None:
      """Sends a successful response message."""
      self.send_response(http.HTTPStatus.OK)
      for keyword, value in response.headers.items():
        self.send_header(keyword, value)
      self.end_headers()
      self.wfile.write(response.body)

  return RequestHandler


class _HttpMethod(enum.Enum):
  GET = 1
  POST = 2


@dataclasses.dataclass(frozen=True)
class _HttpActionData:
  """Data tracked for HTTP actions.

  Attributes:
    method: The name of the HTTP method to handle.
    path: Requests matching this pattern will be handled.
    handler: The handler function, which receives the path match, request body,
      and decorated function.
  """
  method: _HttpMethod
  path: Pattern[str]
  handler: Callable[[Match[str], bytes, Callable[..., Any]], HttpResponse]


def _convert_pattern(pattern: str, alt_proto=False) -> Pattern[str]:
  """Converts a Google API pattern to a regexp with named groups."""
  # Subfields are not supported and will generate a regexp compilation error.
  pattern_regexp = re.sub(r'\\\{(.+?)\\\}', r'(?P<\1>[^/?]*)',
                          re.escape(pattern))
  if alt_proto:
    pattern_regexp += r'\?%24alt=proto'
  try:
    return re.compile(pattern_regexp)
  except re.error as e:
    raise ValueError(f'unable to convert `{pattern}` to a regexp') from e
