# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module related to code analysis and generation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from ctypes import util
import itertools
import os
import re
# pylint: disable=unused-import
from typing import (Text, List, Optional, Set, Dict, Callable, IO,
                    Generator as Gen, Tuple, Union, Sequence)  # pyformat: disable
# pylint: enable=unused-import
from clang import cindex


# Option bitmask passed to libclang for every parse performed by this module.
_PARSE_OPTIONS = (
    # for include directives
    cindex.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD
    | cindex.TranslationUnit.PARSE_INCOMPLETE
    | cindex.TranslationUnit.PARSE_SKIP_FUNCTION_BODIES)


def _init_libclang():
  """Points cindex at a libclang shared library, unless one is loaded already.

  The unversioned library name is probed first, followed by versioned names
  (as used on Debian and others) from newest to oldest. The first name that
  ctypes is able to resolve is configured; when none resolves, the cindex
  configuration is left untouched.
  """
  if cindex.Config.loaded:
    return

  # Library lookup is delegated to ctypes.util.find_library, so whatever
  # locations it consults (e.g. LD_LIBRARY_PATH, if set) are honoured.
  candidates = ['clang'] + [
      'clang-' + version
      for version in ('16', '15', '14', '13', '12', '11', '10', '9', '8', '7',
                      '6.0', '5.0', '4.0')
  ]
  for candidate in candidates:
    resolved = util.find_library(candidate)
    if resolved:
      cindex.Config.set_library_file(resolved)
      return


def get_header_guard(path):
  # type: (Text) -> Text
  """Generates header guard string from path.

  Args:
    path: path of the generated header; must be non-empty

  Returns:
    the upper-cased path with '.', '-' and '/' replaced by '_' and a trailing
    '_' appended. Everything up to and including the first 'genfiles/' is
    dropped, as is one trailing '.gen' suffix.

  Raises:
    ValueError: if path is empty or None
  """
  if not path:
    raise ValueError('Cannot prepare header guard from path: {}'.format(path))

  # The output file will most likely live somewhere in genfiles; strip that
  # prefix. Split only once so the remainder of the path is kept intact even
  # when the marker appears again further down.
  genfiles_marker = 'genfiles/'
  if genfiles_marker in path:
    path = path.split(genfiles_marker, 1)[1]

  # Strip .gen if this is a step before clang-format. Slice the suffix off
  # instead of splitting, so that an earlier '.gen' (e.g. a 'foo.generated'
  # directory) does not truncate the path.
  gen_suffix = '.gen'
  if path.endswith(gen_suffix):
    path = path[:-len(gen_suffix)]

  path = path.upper().replace('.', '_').replace('-', '_').replace('/', '_')
  return path + '_'


def _stringify_tokens(tokens, separator='\n'):
  # type: (Sequence[cindex.Token], Text) -> Text
  """Renders tokens as text, one output line per source line.

  Consecutive tokens that share a source line number are grouped into one
  OutputLine; column positions are ignored. Each line starts at the
  indentation level left behind by the line before it.

  Args:
    tokens: tokens to render, in source order
    separator: string placed between rendered lines

  Returns:
    the rendered lines joined by separator
  """
  rendered = []  # type: List[Text]
  indent = 0  # indentation level inherited by the next line
  for _, same_line in itertools.groupby(tokens,
                                        key=lambda tok: tok.location.line):
    out_line = OutputLine(indent, list(same_line))
    rendered.append(str(out_line))
    indent = out_line.next_tab

  return separator.join(rendered)


# Maps clang builtin type kinds to the SAPI variable class that wraps them.
# Kinds absent from this table are not treated as "simple" types.
TYPE_MAPPING = {
    cindex.TypeKind.VOID: '::sapi::v::Void',
    cindex.TypeKind.CHAR_S: '::sapi::v::Char',
    cindex.TypeKind.CHAR_U: '::sapi::v::Char',
    cindex.TypeKind.INT: '::sapi::v::Int',
    cindex.TypeKind.UINT: '::sapi::v::UInt',
    cindex.TypeKind.LONG: '::sapi::v::Long',
    cindex.TypeKind.ULONG: '::sapi::v::ULong',
    cindex.TypeKind.UCHAR: '::sapi::v::UChar',
    cindex.TypeKind.USHORT: '::sapi::v::UShort',
    cindex.TypeKind.SHORT: '::sapi::v::Short',
    cindex.TypeKind.LONGLONG: '::sapi::v::LLong',
    cindex.TypeKind.ULONGLONG: '::sapi::v::ULLong',
    cindex.TypeKind.FLOAT: '::sapi::v::Reg<float>',
    cindex.TypeKind.DOUBLE: '::sapi::v::Reg<double>',
    cindex.TypeKind.LONGDOUBLE: '::sapi::v::Reg<long double>',
    cindex.TypeKind.SCHAR: '::sapi::v::SChar',
    cindex.TypeKind.BOOL: '::sapi::v::Bool',
}


class Type(object):
  """Class representing a type.

  Wraps cindex.Type of the argument/return value and provides helpers for the
  code generation.

  Equality and hashing are keyed on the USR of the type's declaration, while
  ordering follows the declaration's position recorded in the `order` mapping
  of the owning translation unit.
  """

  def __init__(self, tu, clang_type):
    # type: (_TranslationUnit, cindex.Type) -> None
    """Stores the wrapped clang type and the translation unit it belongs to."""
    self._clang_type = clang_type
    self._tu = tu

  # pylint: disable=protected-access
  def __eq__(self, other):
    # type: (object) -> bool
    """Returns True if both types resolve to a declaration with the same USR."""
    # Let Python fall back to its default handling for foreign objects instead
    # of failing with AttributeError on the missing _get_declaration().
    if not isinstance(other, Type):
      return NotImplemented
    # Use get_usr() to deduplicate Type objects based on declaration
    decl = self._get_declaration()
    decl_o = other._get_declaration()

    return decl.get_usr() == decl_o.get_usr()

  def __ne__(self, other):
    # type: (object) -> bool
    """Negation of __eq__; needed explicitly for Python 2."""
    equal = self.__eq__(other)
    if equal is NotImplemented:
      return NotImplemented
    return not equal

  def __lt__(self, other):
    # type: (Type) -> bool
    """Compares two Types belonging to the same TranslationUnit.

    This is being used to properly order types before emitting to generated
    file. To be more specific: structure definition that contains field that is
    a typedef should end up after that typedef definition. This is achieved by
    exploiting the order in which clang iterate over AST in translation unit.

    Args:
      other: other comparison type

    Returns:
      true if this Type occurs earlier in the AST than 'other'

    Raises:
      ValueError: if the types come from different translation units
    """
    self._validate_tu(other)
    return (self._tu.order[self._get_declaration().hash] <
            self._tu.order[other._get_declaration().hash])  # pylint: disable=protected-access

  def __gt__(self, other):
    # type: (Type) -> bool
    """Compares two Types belonging to the same TranslationUnit.

    Mirror image of __lt__: see there for why AST order is used.

    Args:
      other: other comparison type

    Returns:
      true if this Type occurs later in the AST than 'other'

    Raises:
      ValueError: if the types come from different translation units
    """
    self._validate_tu(other)
    return (self._tu.order[self._get_declaration().hash] >
            self._tu.order[other._get_declaration().hash])  # pylint: disable=protected-access

  def __hash__(self):
    # type: () -> int
    """Types with the same declaration should hash to the same value."""
    return hash(self._get_declaration().get_usr())

  def _validate_tu(self, other):
    # type: (Type) -> None
    """Raises ValueError unless both types share one translation unit."""
    if self._tu != other._tu:  # pylint: disable=protected-access
      raise ValueError('Cannot compare types from different translation units.')

  def is_void(self):
    # type: () -> bool
    """Returns True if the wrapped type kind is VOID."""
    return self._clang_type.kind == cindex.TypeKind.VOID

  def is_typedef(self):
    # type: () -> bool
    """Returns True if the wrapped type kind is TYPEDEF."""
    return self._clang_type.kind == cindex.TypeKind.TYPEDEF

  def is_elaborated(self):
    # type: () -> bool
    """Returns True if the wrapped type kind is ELABORATED."""
    return self._clang_type.kind == cindex.TypeKind.ELABORATED

  # Hack: both class and struct types are indistinguishable except for
  # declaration cursor kind
  def is_sugared_record(self):  # class, struct, union
    # type: () -> bool
    """Returns True if the type is declared as a struct, union or class."""
    return self._clang_type.get_declaration().kind in (
        cindex.CursorKind.STRUCT_DECL, cindex.CursorKind.UNION_DECL,
        cindex.CursorKind.CLASS_DECL)

  def is_struct(self):
    # type: () -> bool
    """Returns True if the type's declaration cursor is a STRUCT_DECL."""
    return (self._clang_type.get_declaration().kind ==
            cindex.CursorKind.STRUCT_DECL)

  def is_class(self):
    # type: () -> bool
    """Returns True if the type's declaration cursor is a CLASS_DECL."""
    return (self._clang_type.get_declaration().kind ==
            cindex.CursorKind.CLASS_DECL)

  def is_union(self):
    # type: () -> bool
    """Returns True if the type's declaration cursor is a UNION_DECL."""
    return (self._clang_type.get_declaration().kind ==
            cindex.CursorKind.UNION_DECL)

  def is_function(self):
    # type: () -> bool
    """Returns True if the wrapped type kind is FUNCTIONPROTO."""
    return self._clang_type.kind == cindex.TypeKind.FUNCTIONPROTO

  def is_sugared_ptr(self):
    # type: () -> bool
    """Returns True if the canonical form of the type is a pointer."""
    return self._clang_type.get_canonical().kind == cindex.TypeKind.POINTER

  def is_sugared_enum(self):
    # type: () -> bool
    """Returns True if the canonical form of the type is an enum."""
    return self._clang_type.get_canonical().kind == cindex.TypeKind.ENUM

  def is_const_array(self):
    # type: () -> bool
    """Returns True if the wrapped type kind is CONSTANTARRAY."""
    return self._clang_type.kind == cindex.TypeKind.CONSTANTARRAY

  def is_simple_type(self):
    # type: () -> bool
    """Returns True if the wrapped type kind has an entry in TYPE_MAPPING."""
    return self._clang_type.kind in TYPE_MAPPING

  def get_pointee(self):
    # type: () -> Type
    """Returns the pointed-to type, wrapped for the same translation unit."""
    return Type(self._tu, self._clang_type.get_pointee())

  def _get_declaration(self):
    # type: () -> cindex.Cursor
    """Returns the declaration cursor, looking through pointers if needed."""
    decl = self._clang_type.get_declaration()
    # A pointer has no declaration of its own; use the pointee's instead.
    if decl.kind == cindex.CursorKind.NO_DECL_FOUND and self.is_sugared_ptr():
      decl = self.get_pointee()._get_declaration()  # pylint: disable=protected-access

    return decl

  def get_related_types(self, result=None, skip_self=False):
    # type: (Optional[Set[Type]], bool) -> Set[Type]
    """Returns all types related to this one eg. typedefs, nested structs.

    Args:
      result: accumulator that is updated in place and returned; a fresh set
        is created when omitted
      skip_self: if True, a record or enum does not add itself to the result
        (its fields are still visited)

    Returns:
      the accumulator set
    """
    if result is None:
      result = set()

    # Base case.
    if self in result or self.is_simple_type() or self.is_class():
      return result

    # Sugar types.
    if self.is_typedef():
      return self._get_related_types_of_typedef(result)

    if self.is_elaborated():
      return Type(self._tu,
                  self._clang_type.get_named_type()).get_related_types(
                      result, skip_self)

    # Composite types.
    if self.is_const_array():
      t = Type(self._tu, self._clang_type.get_array_element_type())
      return t.get_related_types(result)

    if self._clang_type.kind in (cindex.TypeKind.POINTER,
                                 cindex.TypeKind.MEMBERPOINTER,
                                 cindex.TypeKind.LVALUEREFERENCE,
                                 cindex.TypeKind.RVALUEREFERENCE):
      return self.get_pointee().get_related_types(result, skip_self)

    # union + struct, class should be filtered out
    if self.is_struct() or self.is_union():
      return self._get_related_types_of_record(result, skip_self)

    if self.is_function():
      return self._get_related_types_of_function(result)

    if self.is_sugared_enum():
      if not skip_self:
        result.add(self)
        self._tu.search_for_macro_name(self._get_declaration())
      return result

    # Ignore all cindex.TypeKind.UNEXPOSED AST nodes
    # TODO(b/256934562): Remove the disable once the pytype bug is fixed.
    return result  # pytype: disable=bad-return-type

  def _get_related_types_of_typedef(self, result):
    # type: (Set[Type]) -> Set[Type]
    """Returns all intermediate types related to the typedef.

    Adds the typedef itself to result, then follows its underlying type (or
    the pointee, if the underlying type is a pointer).

    Args:
      result: accumulator set, updated in place

    Returns:
      the accumulator set
    """
    result.add(self)
    decl = self._clang_type.get_declaration()
    self._tu.search_for_macro_name(decl)

    t = Type(self._tu, decl.underlying_typedef_type)
    if t.is_sugared_ptr():
      t = t.get_pointee()

    if not t.is_simple_type():
      skip_child = self.contains_declaration(t)
      if t.is_sugared_record() and skip_child:
        # if child declaration is contained in parent, we don't have to emit it
        self._tu.types_to_skip.add(t)
      result.update(t.get_related_types(result, skip_child))

    return result

  def _get_related_types_of_record(self, result, skip_self=False):
    # type: (Set[Type], bool) -> Set[Type]
    """Returns all types related to the structure.

    Args:
      result: accumulator set, updated in place
      skip_self: if True, the record itself is not added to result

    Returns:
      the accumulator set
    """
    # skip unnamed structures eg. typedef struct {...} x;
    # struct {...} will be rendered as part of typedef rendering
    decl = self._get_declaration()
    if not decl.is_anonymous() and not skip_self:
      self._tu.search_for_macro_name(decl)
      result.add(self)

    for f in self._clang_type.get_fields():
      self._tu.search_for_macro_name(f)
      result.update(Type(self._tu, f.type).get_related_types(result))

    return result

  def _get_related_types_of_function(self, result):
    # type: (Set[Type]) -> Set[Type]
    """Returns all types related to the function.

    Visits every argument type and then the result type.

    Args:
      result: accumulator set, updated in place

    Returns:
      the accumulator set
    """
    for arg in self._clang_type.argument_types():
      result.update(Type(self._tu, arg).get_related_types(result))
    related = Type(self._tu,
                   self._clang_type.get_result()).get_related_types(result)
    result.update(related)

    return result

  def contains_declaration(self, other):
    # type: (Type) -> bool
    """Checks if this type's declaration extent encloses the other type's.

    Args:
      other: type whose declaration is tested for being nested in this one

    Returns:
      True if both the start and the end of other's declaration extent lie
      within this declaration's extent; False if other's declaration has no
      source file.
    """
    self_extent = self._get_declaration().extent
    other_extent = other._get_declaration().extent  # pylint: disable=protected-access

    if other_extent.start.file is None:
      return False
    return (other_extent.start in self_extent and
            other_extent.end in self_extent)

  def stringify(self):
    # type: () -> Text
    """Returns string representation of the Type.

    The text is rebuilt from the tokens of the declaration, comments excluded.
    """
    # (szwl): as simple as possible, keeps macros in separate lines not to
    # break things; this will go through clang format nevertheless
    tokens = [
        x for x in self._get_declaration().get_tokens()
        if x.kind is not cindex.TokenKind.COMMENT
    ]

    return _stringify_tokens(tokens)


class OutputLine(object):
  """One rendered line of a declaration, with indentation bookkeeping.

  Attributes:
    tokens: the tokens that make up the line
    spellings: token spellings interleaved with the single-space separators
    define: True once a '#' token has been seen on the line
    tab: indentation level used when printing this line
    next_tab: indentation level the following line should start with
  """

  def __init__(self, tab, tokens):
    # type: (int, List[cindex.Token]) -> None
    """Consumes tokens, starting from indentation level tab."""
    self.tokens = tokens
    self.spellings = []
    self.define = False
    self.tab = tab
    self.next_tab = tab
    for token in self.tokens:
      self._process_token(token)

  def _process_token(self, t):
    # type: (cindex.Token) -> None
    """Appends one token and updates the indentation state it implies."""
    text = t.spelling
    if text == '#':
      self.define = True
    elif text == '{':
      # An opening brace only indents the lines that follow.
      self.next_tab += 1
    elif text == '}':
      # A closing brace already dedents the line it appears on.
      self.tab -= 1
      self.next_tab -= 1

    # No separator at the start of the line, in front of '(' or directly
    # after a lone leading '#'.
    directly_after_hash = self.spellings == ['#']
    needs_space = bool(self.spellings) and text != '(' and (
        not directly_after_hash)
    if needs_space:
      self.spellings.append(' ')
    self.spellings.append(text)

  def __str__(self):
    # type: () -> Text
    """Returns the line text; lines containing '#' are never indented."""
    indent = '' if self.define else '\t' * self.tab
    return indent + ''.join(self.spellings)


class ArgumentType(Type):
  """Class representing function argument type.

  Object fields are being used by the code template:
  pos: argument position
  name: argument name; 'a<pos>' when the declaration gives none
  type: string representation of the type
  mapped_type: SAPI equivalent of the type
  wrapped: wraps type in SAPI object constructor
  call_argument: type (or its sapi wrapper) used in function call

  str() of the object yields the argument as it appears in the generated
  function signature.
  """

  def __init__(self, function, pos, arg_type, name=None):
    # type: (Function, int, cindex.Type, Optional[Text]) -> None
    """Initializes the argument.

    Args:
      function: function the argument belongs to
      pos: position of the argument
      arg_type: clang type of the argument
      name: name of the argument, if it has one
    """
    super(ArgumentType, self).__init__(function.translation_unit(), arg_type)
    self._function = function

    self.pos = pos
    self.name = name or 'a{}'.format(pos)
    self.type = arg_type.spelling

    # Pointers are passed through as-is; everything else is passed as the
    # address of the local '<name>_' wrapper object.
    template = '{}' if self.is_sugared_ptr() else '&{}_'
    self.call_argument = template.format(self.name)

  def __str__(self):
    # type: () -> Text
    """Returns function argument prepared from the type."""
    if self.is_sugared_ptr():
      return '::sapi::v::Ptr* {}'.format(self.name)

    return '{} {}'.format(self._clang_type.spelling, self.name)

  @property
  def wrapped(self):
    # type: () -> Text
    """Returns the declaration of the SAPI wrapper object for the argument."""
    return '{} {name}_(({name}))'.format(self.mapped_type, name=self.name)

  @property
  def mapped_type(self):
    # type: () -> Text
    """Maps the type to its SAPI equivalent.

    Returns:
      spelling of the SAPI type wrapping this type

    Raises:
      ValueError: if the type is a record passed by value
      KeyError: if the type kind has no entry in TYPE_MAPPING
    """
    if self.is_sugared_ptr():
      # TODO(szwl): const ptrs do not play well with SAPI C++ API...
      # Strip only the 'const' keyword; a plain substring replace would also
      # damage identifiers that merely contain it (eg. 'constraint_t').
      spelling = re.sub(r'\bconst\b', '', self._clang_type.spelling)
      return '::sapi::v::Reg<{}>'.format(spelling)

    type_ = self._clang_type

    if type_.kind == cindex.TypeKind.TYPEDEF:
      type_ = self._clang_type.get_canonical()
    if type_.kind == cindex.TypeKind.ELABORATED:
      type_ = type_.get_canonical()
    if type_.kind == cindex.TypeKind.ENUM:
      return '::sapi::v::IntBase<{}>'.format(self._clang_type.spelling)
    if type_.kind in [
        cindex.TypeKind.CONSTANTARRAY, cindex.TypeKind.INCOMPLETEARRAY
    ]:
      return '::sapi::v::Reg<{}>'.format(self._clang_type.spelling)

    if type_.kind == cindex.TypeKind.LVALUEREFERENCE:
      return 'LVALUEREFERENCE::NOT_SUPPORTED'

    if type_.kind == cindex.TypeKind.RVALUEREFERENCE:
      return 'RVALUEREFERENCE::NOT_SUPPORTED'

    if type_.kind in [cindex.TypeKind.RECORD, cindex.TypeKind.ELABORATED]:
      raise ValueError('Elaborate type (eg. struct) in mapped_type is not '
                       'supported: function {}, arg {}, type {}, location {}'
                       ''.format(self._function.name, self.pos,
                                 self._clang_type.spelling,
                                 self._function.cursor.location))

    if type_.kind not in TYPE_MAPPING:
      raise KeyError('Key {} does not exist in TYPE_MAPPING.'
                     ' function {}, arg {}, type {}, location {}'
                     ''.format(type_.kind, self._function.name, self.pos,
                               self._clang_type.spelling,
                               self._function.cursor.location))

    return TYPE_MAPPING[type_.kind]


class ReturnType(ArgumentType):
  """Class representing function return type.

  str() of the object yields the return type of the generated wrapper:
  absl::StatusOr<T> where T is original return type, or absl::Status for
  functions returning void.
  """

  def __init__(self, function, arg_type):
    # type: (Function, cindex.Type) -> None
    """Initializes the return type as an unnamed argument at position 0."""
    super(ReturnType, self).__init__(function, 0, arg_type, None)

  def __str__(self):
    # type: () -> Text
    """Returns function return type prepared from the type."""
    if self.is_void():
      return 'absl::Status'
    # TODO(szwl): const ptrs do not play well with SAPI C++ API...
    # Strip only the 'const' keyword; a plain substring replace would also
    # damage identifiers that merely contain it (eg. 'constraint_t').
    spelling = re.sub(r'\bconst\b', '', self._clang_type.spelling)
    return 'absl::StatusOr<{}>'.format(spelling)


class Function(object):
  """Class representing SAPI-wrapped function used by the template.

  Wraps Clang cursor object of kind FUNCTION_DECL and provides helpers to
  aid code generation.

  Attributes:
    cursor: the wrapped clang cursor
    name: spelling of the function
    result: ReturnType built from the cursor's result type
    original_definition: '<result type> <display name>' of the function
    argument_types: ArgumentType objects, one per argument, in order
  """

  def __init__(self, tu, cursor):
    # type: (_TranslationUnit, cindex.Cursor) -> None
    """Initializes the function from its declaration cursor."""
    self._tu = tu
    self.cursor = cursor  # type: cindex.Cursor
    self.name = cursor.spelling  # type: Text
    self.result = ReturnType(self, cursor.result_type)
    self.original_definition = '{} {}'.format(
        cursor.result_type.spelling, self.cursor.displayname)  # type: Text

    types = self.cursor.get_arguments()
    self.argument_types = [
        ArgumentType(self, i, t.type, t.spelling) for i, t in enumerate(types)
    ]

  def translation_unit(self):
    # type: () -> _TranslationUnit
    """Returns the translation unit the function was found in."""
    return self._tu

  def arguments(self):
    # type: () -> List[ArgumentType]
    """Returns the argument types of the function."""
    return self.argument_types

  def call_arguments(self):
    # type: () -> List[Text]
    """Returns a new list with the call expression of every argument."""
    return [a.call_argument for a in self.argument_types]

  def get_absolute_path(self):
    # type: () -> Text
    """Returns the name of the file that holds the declaration."""
    return self.cursor.location.file.name

  def get_include_path(self, prefix):
    # type: (Optional[Text]) -> Text
    """Creates a proper include path.

    Args:
      prefix: directory prefix to root the include path at; a trailing '/' is
        added when missing

    Returns:
      the path of the declaring file if prefix is empty; otherwise prefix
      followed by whatever comes after its last occurrence in that path, or
      by the bare file name if prefix does not occur in the path at all
    """
    # TODO(szwl): sanity checks
    # TODO(szwl): prefix 'utils/' and the path is '.../fileutils/...' case
    if prefix and not prefix.endswith('/'):
      prefix += '/'

    absolute_path = self.get_absolute_path()
    if not prefix:
      return absolute_path
    if prefix in absolute_path:
      return prefix + absolute_path.split(prefix)[-1]
    return prefix + absolute_path.split('/')[-1]

  def get_related_types(self, processed=None):
    # type: (Optional[Set[Type]]) -> Set[Type]
    """Returns the types related to the result and to every argument.

    Args:
      processed: forwarded as the accumulator to each type's
        get_related_types()

    Returns:
      set of related types
    """
    result = self.result.get_related_types(processed)
    for a in self.argument_types:
      result.update(a.get_related_types(processed))

    return result

  def is_mangled(self):
    # type: () -> bool
    """Returns True if the mangled name differs from the spelling."""
    return self.cursor.mangled_name != self.cursor.spelling

  def __hash__(self):
    # type: () -> int
    return hash(self.cursor.get_usr())

  def __eq__(self, other):
    # type: (object) -> bool
    """Returns True if both functions have the same mangled name."""
    # Let Python fall back to its default handling for foreign objects instead
    # of failing with AttributeError on the missing cursor.
    if not isinstance(other, Function):
      return NotImplemented
    return self.cursor.mangled_name == other.cursor.mangled_name

  def __ne__(self, other):
    # type: (object) -> bool
    """Negation of __eq__; needed explicitly for Python 2."""
    equal = self.__eq__(other)
    if equal is NotImplemented:
      return NotImplemented
    return not equal


class _TranslationUnit(object):
  """Class wrapping clang's _TranslationUnit. Provides extra utilities."""

  def __init__(self, path, tu, limit_scan_depth=False):
    # type: (Text, cindex.TranslationUnit, bool) -> None
    self.path = path
    self.limit_scan_depth = limit_scan_depth
    self._tu = tu
    self._processed = False
    self.forward_decls = dict()
    self.functions = set()
    self.order = dict()
    self.defines = {}
    self.required_defines = set()
    self.types_to_skip = set()

  def _process(self):
    # type: () -> None
    """Walks the cursor tree and caches some for future use."""
    if not self._processed:
      # self.includes[self._tu.spelling] = (0, self._tu.cursor)
      self._processed = True
      # TODO(szwl): duplicates?
      # TODO(szwl): for d in translation_unit.diagnostics:, handle that

      for i, cursor in enumerate(self._walk_preorder()):
        # Workaround for issue#32
        # ignore all the cursors with kinds not implemented in python bindings
        try:
          cursor.kind
        except ValueError:
          continue
        # naive way to order types: they should be ordered when walking the tree
        if cursor.kind.is_declaration():
          self.order[cursor.hash] = i

        if (cursor.kind == cindex.CursorKind.MACRO_DEFINITION and
            cursor.location.file):
          self.order[cursor.hash] = i
          self.defines[cursor.spelling] = cursor

        # most likely a forward decl of struct
        if (cursor.kind == cindex.CursorKind.STRUCT_DECL and
            not cursor.is_definition()):
          self.forward_decls[Type(self, cursor.type)] = cursor
        if (cursor.kind == cindex.CursorKind.FUNCTION_DECL and
            cursor.linkage != cindex.LinkageKind.INTERNAL):
          if self.limit_scan_depth:
            if (cursor.location and cursor.location.file.name == self.path):
              self.functions.add(Function(self, cursor))
          else:
            self.functions.add(Function(self, cursor))

  def get_functions(self):
    # type: () -> Set[Function]
    if not self._processed:
      self._process()
    return self.functions

  def _walk_preorder(self):
    # type: () -> Gen
    for c in self._tu.cursor.walk_preorder():
      yield c

  def search_for_macro_name(self, cursor):
    # type: (cindex.Cursor) -> None
    """Searches for possible macro usage in constant array types."""
    tokens = list(t.spelling for t in cursor.get_tokens())
    try:
      for token in tokens:
        if token in self.defines and token not in self.required_defines:
          self.required_defines.add(token)
          self.search_for_macro_name(self.defines[token])
    except ValueError:
      return


class Analyzer(object):
  """Parses source files with libclang into _TranslationUnit wrappers."""

  @staticmethod
  def process_files(input_paths, compile_flags, limit_scan_depth=False):
    # type: (List[Text], List[Text], bool) -> List[_TranslationUnit]
    """Processes files with libclang and returns TranslationUnit objects.

    Args:
      input_paths: paths of the files to parse
      compile_flags: extra arguments handed to clang for every file
      limit_scan_depth: forwarded to every created _TranslationUnit

    Returns:
      one _TranslationUnit per input path, in input order
    """
    _init_libclang()

    return [
        Analyzer._analyze_file_for_tu(
            input_path,
            compile_flags=compile_flags,
            limit_scan_depth=limit_scan_depth) for input_path in input_paths
    ]

  # pylint: disable=line-too-long
  @staticmethod
  def _analyze_file_for_tu(path,
                           compile_flags=None,
                           test_file_existence=True,
                           unsaved_files=None,
                           limit_scan_depth=False):
    # type: (Text, Optional[List[Text]], bool, Optional[List[Tuple[Text, Union[Text, IO[Text]]]]], bool) -> _TranslationUnit
    """Parses one file and wraps the result in a _TranslationUnit.

    Args:
      path: path of the file to parse
      compile_flags: extra arguments handed to clang
      test_file_existence: if True, fail early when path is not a file
      unsaved_files: passed through to the libclang parse call
      limit_scan_depth: forwarded to the created _TranslationUnit

    Returns:
      the _TranslationUnit wrapping the parse result

    Raises:
      IOError: if test_file_existence is set and path is not an existing file
    """
    extra_flags = compile_flags or []
    if test_file_existence and not os.path.isfile(path):
      raise IOError('Path {} does not exist.'.format(path))

    _init_libclang()
    index = cindex.Index.create()  # type: cindex.Index
    # TODO(szwl): hack until I figure out how python swig does that.
    # Headers will be parsed as C++. C libs usually have
    # '#ifdef __cplusplus extern "C"' for compatibility with c++
    language_flag = '-xc' if path.endswith('.c') else '-xc++'
    clang_args = [language_flag]
    clang_args.extend(extra_flags)
    clang_args.append('-I.')

    parsed = index.parse(
        path,
        args=clang_args,
        unsaved_files=unsaved_files,
        options=_PARSE_OPTIONS)
    return _TranslationUnit(path, parsed, limit_scan_depth=limit_scan_depth)


class Generator(object):
  """Class responsible for code generation."""

  AUTO_GENERATED = ('// AUTO-GENERATED by the Sandboxed API generator.\n'
                    '// Edits will be discarded when regenerating this file.\n')

  GUARD_START = '#ifndef {0}\n#define {0}'
  GUARD_END = '#endif  // {}'
  EMBED_INCLUDE = '#include "{}"'
  EMBED_CLASS = ('class {0}Sandbox : public ::sapi::Sandbox {{\n'
                 ' public:\n'
                 '  {0}Sandbox() : ::sapi::Sandbox({1}_embed_create()) {{}}\n'
                 '}};')

  def __init__(self, translation_units):
    # type: (List[_TranslationUnit]) -> None
    """Initializes the generator.

    Args:
      translation_units: translation units of the analyzed files
    """
    self.translation_units = translation_units
    # Filled in by the first call to _get_functions() and reused afterwards.
    self.functions = None
    _init_libclang()

  def generate(self,
               name,
               function_names,
               namespace=None,
               output_file=None,
               embed_dir=None,
               embed_name=None):
    # pylint: disable=line-too-long
    # type: (Text, List[Text], Optional[Text], Optional[Text], Optional[Text], Optional[Text]) -> Text
    """Generates structures, functions and typedefs.

    Args:
      name: name of the class that will contain generated interface
      function_names: list of function names to export to the interface;
        empty means all functions
      namespace: '::'-separated namespace of the interface
      output_file: path to the output file, used to generate header guards;
        when None, no header guard is emitted
      embed_dir: path to directory with embed includes
      embed_name: name of the embed object; when None, neither the embed
        include nor the sandbox class is emitted

    Returns:
      generated interface as a string
    """
    types = self._get_related_types(function_names)
    forward_decls = self._get_forward_decls(types)
    functions = self._get_functions(function_names)
    rendered_types = [t.stringify() + ';' for t in types]
    # Must come after the types were collected: collecting them is what
    # records the required defines on the translation units.
    defines = self._get_defines()

    return self.format_template(
        name=name,
        functions=functions,
        related_types=defines + forward_decls + rendered_types,
        namespaces=namespace.split('::') if namespace else [],
        embed_dir=embed_dir,
        embed_name=embed_name,
        output_file=output_file)

  def _get_functions(self, func_names=None):
    # type: (Optional[List[Text]]) -> List[Function]
    """Gets Function objects that will be used to generate interface.

    The result of the first call is cached; later calls return it regardless
    of func_names.

    Args:
      func_names: names of the functions to keep, empty means all functions

    Returns:
      non-mangled, de-duplicated functions sorted by name
    """
    if self.functions is not None:
      return self.functions

    self.functions = []
    # TODO(szwl): for d in translation_unit.diagnostics:, handle that
    for tu in self.translation_units:
      self.functions.extend([
          fn for fn in tu.get_functions()
          if not func_names or fn.name in func_names
      ])

    # allow only nonmangled functions - C++ overloads are not handled in
    # code generation; the set removes duplicates
    unique = set([fn for fn in self.functions if not fn.is_mangled()])
    self.functions = sorted(unique, key=lambda fn: fn.name)
    return self.functions

  def _get_related_types(self, func_names=None):
    # type: (Optional[List[Text]]) -> List[Type]
    """Gets type definitions related to chosen functions.

    Types related to one function will land in the same translation unit,
    so they are gathered and sorted per function before being appended to the
    overall list. This is necessary as types from two different translation
    units cannot be compared.

    Args:
      func_names: list of function names to take into consideration, empty means
        all functions.

    Returns:
      list of types in correct (ready to render) order
    """
    seen = set()
    ordered = []
    skipped = set()

    for fn in self._get_functions(func_names):
      fn_types = fn.get_related_types()
      ordered.extend(sorted(t for t in fn_types if t not in seen))
      seen.update(fn_types)
      skipped.update(fn.translation_unit().types_to_skip)

    return [t for t in ordered if t not in skipped]

  def _get_defines(self):
    # type: () -> List[Text]
    """Gets #define directives that appeared during TranslationUnit processing.

    Returns:
      list of #define string representations, ordered per translation unit by
      the position of the macro definition
    """
    rendered = []
    for tu in self.translation_units:
      needed = [
          tu.defines[name] for name in tu.required_defines if name in tu.defines
      ]
      # The key is consumed right away, so capturing the loop variable is safe.
      needed.sort(key=lambda cursor: tu.order[cursor.hash])  # pylint: disable=cell-var-from-loop
      for cursor in needed:
        body = _stringify_tokens(cursor.get_tokens(), separator=' \\\n')
        rendered.append('#define ' + body)
    return rendered

  def _get_forward_decls(self, types):
    # type: (List[Type]) -> List[Text]
    """Gets forward declarations of related types, if present.

    Args:
      types: types to look up forward declarations for

    Returns:
      rendered forward declarations, at most one per type
    """
    known = dict()
    emitted = set()
    rendered = []
    for tu in self.translation_units:
      known.update(tu.forward_decls)

      for t in types:
        if t not in known or t in emitted:
          continue
        rendered.append(_stringify_tokens(known[t].get_tokens()) + ';')
        emitted.add(t)

    return rendered

  def _format_function(self, f):
    # type: (Function) -> Text
    """Renders one function of the Api.

    Args:
      f: function object with information necessary to emit full function body

    Returns:
      filled function template
    """
    signature_args = ', '.join(str(a) for a in f.arguments())
    lines = [
        '  // {}'.format(f.original_definition),
        '  {} {}({}) {{'.format(f.result, f.name, signature_args),
        '    {} ret;'.format(f.result.mapped_type),
    ]

    # Non-pointer arguments get a local SAPI wrapper object.
    wrappers = [
        a.wrapped + ';' for a in f.argument_types if not a.is_sugared_ptr()
    ]
    for wrapper in wrappers:
      lines.append('    {}'.format(wrapper))

    # Every call argument is preceded by ', ' so the list can follow '&ret'.
    call_suffix = ''.join(', ' + arg for arg in f.call_arguments())
    lines.append('')
    # For OSS, the macro below will be replaced.
    lines.append('    SAPI_RETURN_IF_ERROR(sandbox_->Call("{}", &ret{}));'
                 ''.format(f.name, call_suffix))

    if not f.result or f.result.is_void():
      return_status = 'return absl::OkStatus();'
    elif f.result.is_sugared_enum():
      return_status = ('return static_cast<{}>'
                       '(ret.GetValue());').format(f.result.type)
    else:
      return_status = 'return ret.GetValue();'
    lines.append('    {}'.format(return_status))
    lines.append('  }')

    return '\n'.join(lines)

  def format_template(self, name, functions, related_types, namespaces,
                      embed_dir, embed_name, output_file):
    # pylint: disable=line-too-long
    # type: (Text, List[Function], List[Text], List[Text], Text, Text, Text) -> Text
    # pylint: enable=line-too-long
    """Formats arguments into proper interface header file.

    Args:
      name: name of the Api - 'Test' will yield TestApi object
      functions: list of functions to generate
      related_types: rendered types (and defines) used by the above functions
      namespaces: list of namespaces to wrap the Api class with
      embed_dir: directory where the embedded library lives
      embed_name: name of embedded library
      output_file: interface output path - used in header guard generation

    Returns:
      generated header file text
    """
    out = [Generator.AUTO_GENERATED]

    header_guard = get_header_guard(output_file) if output_file else ''
    if header_guard:
      out.append(Generator.GUARD_START.format(header_guard))

    # Copybara transform results in the paths below.
    out.extend([
        '#include "absl/status/status.h"',
        '#include "absl/status/statusor.h"',
        '#include "sandboxed_api/sandbox.h"',
        '#include "sandboxed_api/util/status_macros.h"',
        '#include "sandboxed_api/vars.h"',
    ])

    if embed_name:
      embed_header = os.path.join(embed_dir or '', embed_name) + '_embed.h'
      out.append(Generator.EMBED_INCLUDE.format(embed_header))

    if namespaces:
      out.append('')
      out.extend(['namespace {} {{'.format(n) for n in namespaces])

    if related_types:
      out.append('')
      out.extend(related_types)

    out.append('')

    if embed_name:
      out.append(
          Generator.EMBED_CLASS.format(name, embed_name.replace('-', '_')))

    out.extend([
        'class {}Api {{'.format(name),
        ' public:',
        '  explicit {}Api(::sapi::Sandbox* sandbox)'
        ' : sandbox_(sandbox) {{}}'.format(name),
        '  // Deprecated',
        '  ::sapi::Sandbox* GetSandbox() const { return sandbox(); }',
        '  ::sapi::Sandbox* sandbox() const { return sandbox_; }',
    ])

    for fn in functions:
      out.append('')
      out.append(self._format_function(fn))

    out.extend([
        '',
        ' private:',
        '  ::sapi::Sandbox* sandbox_;',
        '};',
        '',
    ])

    if namespaces:
      out.extend(
          ['}}  // namespace {}'.format(n) for n in reversed(namespaces)])

    if header_guard:
      out.append(Generator.GUARD_END.format(header_guard))

    out.append('')

    return '\n'.join(out)
