#!/usr/bin/env python3
#
# Copyright (C) 2024 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Compares Host types against Guest types."""

import warnings


def _is_size_required(atype):
  if atype['kind'] in ['incomplete', 'const', 'volatile', 'restrict', 'function']:
    return False
  return atype['kind'] != 'array' or not bool(atype.get('incomplete', 'false'))


def _compare_list_length(guest_list, host_list):
  if ((guest_list is None) != (host_list is None)):
    return False, False
  if (guest_list is None):
    return True, True
  if len(guest_list) == len(host_list):
    return True, False
  return False, False


class APIComparator(object):
  """Compares guest type descriptions against host type descriptions.

  Both type tables map a type name to a description dict with at least a
  'kind' key. compare_types() checks a (guest, host) name pair and,
  recursively, the pairs it references. Every reference is recorded in
  type_references. Pairs found incompatible are collected in `incompatible`.
  propagate_incompatible() later extends that set to every pair that
  (transitively) references an incompatible pair, unless the referencing
  pair is forced to be compatible.
  """

  # Kinds compare_types() accepts; anything else trips its final assert.
  # ('dependent' is rejected earlier and is deliberately not listed.)
  _KNOWN_KINDS = frozenset((
      'array',
      'atomic',
      'char',
      'class',
      'const',
      'incomplete',
      'int',
      'float',
      'fp',  # TODO: remove - this not used by new json generator.
      'function',
      'nullptr_t',
      'pointer',
      'reference',
      'restrict',
      'rvalue_reference',
      'struct',
      'union',
      'UNSUPPORTED Enum',  # TODO: what is this?
      'UNSUPPORTED Atomic',  # TODO: and this...
      'volatile',
  ))

  # Qualifier kinds that only wrap another type named by 'base_type'.
  _QUALIFIER_KINDS = ('const', 'volatile', 'restrict')

  # Kinds that cannot be used as automatic trampoline operands.
  _DISALLOWED_TRAMPOLINE_KINDS = (
      'class', 'struct', 'union', 'incomplete', 'array')

  def __init__(self, guest_types, host_types, verbose=False):
    """Args:
      guest_types: dict mapping guest type name to its description dict.
      host_types: dict mapping host type name to its description dict.
      verbose: if True, print why each pair is marked incompatible.
    """
    self.guest_types = guest_types
    self.host_types = host_types
    # Referenced (guest, host) pair -> list of pairs that reference it.
    self.type_references = {}
    # (guest, host) name pairs found to be incompatible.
    self.incompatible = set()
    # (guest, host) name pairs already visited by compare_types().
    self.compared = set()
    self.verbose = verbose

  def _add_incompatible(self, name_pair, reason):
    """Marks name_pair incompatible, printing reason when verbose."""
    if self.verbose:
      print(reason, name_pair)
    self.incompatible.add(name_pair)

  def _notice_type_reference(self, from_type_pair, to_type_pair):
    """Records that from_type_pair references to_type_pair."""
    self.type_references.setdefault(to_type_pair, []).append(from_type_pair)

  def _compare_referenced_types(
      self, guest_type_name, host_type_name, ref_from_pair):
    """Compares a referenced pair and records the reference to it."""
    # If referenced types are incompatible we will notice that
    # in propagate_incompatible().
    self.compare_types(guest_type_name, host_type_name)
    self._notice_type_reference(
        ref_from_pair, (guest_type_name, host_type_name))

  def _compare_size_and_align(self, guest_type, host_type, name_pair):
    """Returns False (marking name_pair) on a size or alignment mismatch."""
    if guest_type['size'] != host_type['size']:
      self._add_incompatible(name_pair, "Types diff in size")
      return False
    # Objects in the guest memory should have at least the same
    # alignment as in host to be passed to host functions.
    # For objects created in host memory we assume that guest code
    # will never check their alignment.
    # TODO(b/232598137): Also we can skip alignment check for types which won't
    # be going to be addressed in memory. E.g. these are types not referenced
    # by pointers (not even by indirect pointers like pointer to a structure
    # containing the type).
    # TODO(b/232598137): DWARF generated by current version of clang does not
    # always provide information about alignment. Fix the dwarf or if this takes
    # too long fix nogrod to provide make educated guesses about alignments of
    # types. See also http://b/77671138
    guest_has_align = 'align' in guest_type
    host_has_align = 'align' in host_type
    if (guest_has_align != host_has_align or
        int(guest_type.get('align', '0')) < int(host_type.get('align', '0'))):
      self._add_incompatible(name_pair, "Types diff in align")
      return False
    return True

  def _compare_record_type_attrs(self, guest_type, host_type, name_pair):
    """Checks class/struct/union specifics: polymorphism and field layout."""
    # In regular case polymorphic objects are created with v_table being in
    # host memory. When this happens it is safe to use polymorphic objects in
    # trampolines.
    # But sometimes polymorphic objects can be created with v_table in guest
    # memory. E. g. this occurs when client inherits from class defined in
    # native_bridge_support-api and creates a derived object. Such objects require special
    # translation thus are not subject for automatic trampolines generation.
    # We don't know whether object has v_table in host or in guest, so we
    # require custom trampolines for all of them.
    # Allow 'is_polymorphic' to be absent for backward-compatibility.
    # TODO(b/232598137): Correct code when all APIs are regenerated.
    if (guest_type.get('is_polymorphic', False) or
        host_type.get('is_polymorphic', False)):
      self._add_incompatible(name_pair, "Types diff due to polymorphism")
      return False

    is_cmp_ok, are_both_none = _compare_list_length(
        guest_type['fields'], host_type['fields'])
    if not is_cmp_ok:
      self._add_incompatible(name_pair, "Types diff in fields list lengths")
      return False
    if are_both_none:
      return True

    # Lengths are equal here, so zip() visits every field pair.
    for guest_field, host_field in zip(guest_type['fields'], host_type['fields']):
      if guest_field['offset'] != host_field['offset']:
        self._add_incompatible(name_pair, "Types diff in field offset")
        return False
      self._compare_referenced_types(
          guest_field['type'], host_field['type'], name_pair)

    return True

  # TODO(b/232598137): Can we generate such trampolines?
  def _is_type_allowed_in_trampoline(self, type_name):
    """Returns whether a guest type may be a trampoline operand.

    Operands are function return and parameter types. Records, unions,
    incomplete types and arrays are rejected, looking through any number of
    const/volatile/restrict qualifiers.
    """
    # void is exception from other types of incomplete kind - it is supported.
    if type_name == 'void':
      return True
    type_desc = self.guest_types[type_name]
    # Qualifiers may be nested (e.g. 'const volatile struct S'); strip all
    # of them before inspecting the underlying kind.
    while type_desc['kind'] in self._QUALIFIER_KINDS:
      type_desc = self.guest_types[type_desc['base_type']]
    return type_desc['kind'] not in self._DISALLOWED_TRAMPOLINE_KINDS

  def _compare_trampoline_operand(
      self, operand_no, guest_name, host_name, name_pair):
    """Compares one function operand (0 = return type, N = N-th parameter)."""
    # We use 'x' trampoline operand type to involve this conversion.
    # Note: We don't make reference from function to this operand. So that
    # if these types are marked incompatible in context other than function
    # parameter, function itself still can be compatible (if everything
    # else is ok).
    # TODO(b/232598137): Define such compatible pairs in custom trampolines?
    if guest_name == 'fp64' and host_name == 'fp96':
      operands = self.guest_types[name_pair[0]].setdefault(
          'long_double_conversion_operands', [])
      # The same guest function may be compared against several host types;
      # record each operand only once.
      if operand_no not in operands:
        operands.append(operand_no)
      return True
    if not self._is_type_allowed_in_trampoline(guest_name):
      self._add_incompatible(
          name_pair, "Types diff due to unallowed operand type")
      return False
    # If we accept pointers to functions we look on the functions themselves.
    # Note: function pointers embedded into data structures make them
    # incompatible, but GetTrampolineFunc knows how to wrap simple callbacks.
    guest_type = self.guest_types[guest_name]
    host_type = self.host_types[host_name]
    if (guest_type['kind'] == 'pointer' and
        self.guest_types[guest_type['pointee_type']]['kind'] == 'function' and
        host_type['kind'] == 'pointer' and
        self.host_types[host_type['pointee_type']]['kind'] == 'function'):
      guest_name = guest_type['pointee_type']
      host_name = host_type['pointee_type']
    self._compare_referenced_types(guest_name, host_name, name_pair)
    return True

  def _compare_function_type_attrs(self, guest_type, host_type, name_pair):
    """Checks return/parameter operands and rejects variadic/virtual ones."""
    if not self._compare_trampoline_operand(
        0, guest_type['return_type'], host_type['return_type'], name_pair):
      return False

    if guest_type['has_variadic_args'] or host_type['has_variadic_args']:
      self._add_incompatible(name_pair, "Types diff due to variadic args")
      return False

    # Allow 'is_virtual_method' to be absent for backward-compatibility.
    # TODO(b/232598137): Correct code when all APIs are regenerated.
    if (guest_type.get('is_virtual_method', False) or
        host_type.get('is_virtual_method', False)):
      self._add_incompatible(name_pair, "Types diff due to virtual method")
      return False

    is_cmp_ok, are_both_none = _compare_list_length(
        guest_type['params'], host_type['params'])
    if not is_cmp_ok:
      self._add_incompatible(name_pair, "Types diff in params lengths")
      return False
    if are_both_none:
      return True

    # Operand 0 is the return type, so parameters are numbered from 1.
    for operand_no, (guest_param, host_param) in enumerate(
        zip(guest_type['params'], host_type['params']), start=1):
      if not self._compare_trampoline_operand(
          operand_no, guest_param, host_param, name_pair):
        return False

    return True

  def _compare_array_type_attrs(self, guest_type, host_type, name_pair):
    """Checks array incompleteness and compares the element types."""
    if guest_type.get('incomplete', 'false') != host_type.get('incomplete', 'false'):
      self._add_incompatible(name_pair, "Types diff in incompleteness")
      return False
    self._compare_referenced_types(
        guest_type['element_type'], host_type['element_type'], name_pair)
    return True

  def _compare_pointer_type_attrs(self, guest_type, host_type, name_pair):
    """Rejects pointers to functions; otherwise compares the pointees."""
    if self.guest_types[guest_type['pointee_type']]['kind'] == 'function':
      self._add_incompatible(
          name_pair, "Types diff due to pointing to function")
      return False
    self._compare_referenced_types(
        guest_type['pointee_type'], host_type['pointee_type'], name_pair)
    return True

  def _is_compatibility_forced(self, name_pair):
    """Returns whether the guest type's description forces this pair."""
    guest_type = self.guest_types[name_pair[0]]
    # Forcing compatible is only supported for the types with the same name.
    if (guest_type.get('force_compatible', False) and
        name_pair[0] == name_pair[1]):
      return True

    return name_pair[1] in guest_type.get('force_compatible_with', [])

  def _set_useful_force_compatible(self, guest_type_name):
    """Flags that forcing compatibility of this guest type had an effect."""
    guest_type = self.guest_types[guest_type_name]
    guest_type['useful_force_compatible'] = True

  # Compare types internals and compare referenced types recursively.
  # References are remembered to propagate incompatibility later.
  # If types are incompatible internally return immediately as
  # referenced types are not of interest.
  #
  # To save us from loops in references between types we propagate
  # incompatibility from referenced types afterwards in
  # propagate_incompatible().
  def compare_types(
      self, guest_type_name, host_type_name):
    name_pair = (guest_type_name, host_type_name)

    # Prevent infinite recursion.
    if name_pair in self.compared:
      return
    self.compared.add(name_pair)

    guest_type = self.guest_types[guest_type_name]
    host_type = self.host_types[host_type_name]

    if guest_type['kind'] != host_type['kind']:
      self._add_incompatible(name_pair, "Types diff in kind")
      return

    kind = guest_type['kind']

    if kind == 'dependent':
      self._add_incompatible(name_pair, "Types depend on template parameters")
      return

    if (_is_size_required(guest_type) and
        not self._compare_size_and_align(guest_type, host_type, name_pair)):
      return

    if kind in ('class', 'struct', 'union'):
      self._compare_record_type_attrs(guest_type, host_type, name_pair)
    elif kind == 'function':
      self._compare_function_type_attrs(guest_type, host_type, name_pair)
    elif kind in ('pointer', 'reference', 'rvalue_reference'):
      self._compare_pointer_type_attrs(guest_type, host_type, name_pair)
    elif kind in ('const', 'volatile', 'restrict', 'atomic'):
      self._compare_referenced_types(
          guest_type['base_type'], host_type['base_type'], name_pair)
    elif kind == 'array':
      self._compare_array_type_attrs(guest_type, host_type, name_pair)

    # If more checks are added here, check return values of the
    # functions in previous if-elif block and return in case of
    # miscomparison.

    # Make sure we did not ignore types we shouldn't have.
    assert kind in self._KNOWN_KINDS, (
        "Unknown type %s kind=%s, couldn't process" % (guest_type_name, kind))

  def _mark_references_as_incompatible(self, ref_to_type_pair):
    """Marks every pair transitively referencing ref_to_type_pair.

    Pairs forced to be compatible stop the propagation (and are flagged as
    usefully forced). An explicit work list is used instead of recursion so
    that long reference chains cannot exceed Python's recursion limit.
    """
    pending = [ref_to_type_pair]
    while pending:
      current_pair = pending.pop()
      for ref_from_type_pair in self.type_references.get(current_pair, []):
        # Go only through compatible types because incompatible types either
        # are already propagated or will be propagated in
        # propagate_incompatible (if they were in initial incompatible set).
        if not self.are_types_compatible(
            ref_from_type_pair[0], ref_from_type_pair[1]):
          continue
        if self._is_compatibility_forced(ref_from_type_pair):
          if self.verbose:
            print(("Not propagating incompatibility to types %s" +
                   " since they are forced to be compatible") %
                  (ref_from_type_pair, ))
          self._set_useful_force_compatible(ref_from_type_pair[0])
          continue
        self._add_incompatible(
            ref_from_type_pair,
            "Incompatible type pair %s is referenced by" % (current_pair,))
        pending.append(ref_from_type_pair)

  def force_compatibility(self):
    """Removes same-name pairs forced compatible from the incompatible set."""
    for atype in self.guest_types:
      name_pair = (atype, atype)
      if atype not in self.host_types:
        continue
      if not self._is_compatibility_forced(name_pair):
        continue
      if self.are_types_compatible(atype, atype):
        if self.verbose:
          print(("Forcing compatibility for internally compatible types %s" +
                 " (maybe referencing other incompatible types)") % (name_pair, ))
        continue
      self._set_useful_force_compatible(atype)
      if self.verbose:
        print("Forcing compatibility for", name_pair)
      self.incompatible.remove(name_pair)

  def propagate_incompatible(self):
    """Spreads incompatibility from each incompatible pair to its referrers."""
    # Make a copy because we expand initial set.
    for type_pair in self.incompatible.copy():
      self._mark_references_as_incompatible(type_pair)

  def are_types_compatible(self, guest_type_name, host_type_name):
    """Returns False only if the pair has been marked incompatible."""
    return (guest_type_name, host_type_name) not in self.incompatible


def mark_incompatible_api_with_comparator(
    comparator, guest_symbols, host_symbols, verbose=False):
  """Runs the compatibility analysis and stores the results in place.

  Sets 'is_compatible' on every guest symbol descriptor and on every guest
  type descriptor in comparator.guest_types. Guest symbols missing on the
  host are always marked incompatible.

  Args:
    comparator: an APIComparator over the guest and host type tables.
    guest_symbols: dict of guest symbol name -> descriptor with a 'type'.
    host_symbols: dict of host symbol name -> descriptor with a 'type'.
    verbose: if True, report guest symbols missing on the host.
  """
  # Symbols present on both sides are compared via their declared types.
  shared_symbols = [name for name in guest_symbols if name in host_symbols]
  for name in shared_symbols:
    comparator.compare_types(
        guest_symbols[name]['type'], host_symbols[name]['type'])

  # Compare all types in case some of them are not referenced by
  # compatible symbols or other compatible types. We might want to use
  # the result of the analysis (e.g. check type compatibility expectation).
  shared_types = [
      type_name for type_name in comparator.guest_types
      if type_name in comparator.host_types]
  for type_name in shared_types:
    comparator.compare_types(type_name, type_name)

  # Forcing must precede propagation so that forced pairs are not
  # propagated as incompatible.
  comparator.force_compatibility()
  comparator.propagate_incompatible()

  for name, symbol_descr in guest_symbols.items():
    if name in host_symbols:
      symbol_descr['is_compatible'] = comparator.are_types_compatible(
          symbol_descr['type'], host_symbols[name]['type'])
    else:
      if verbose:
        print("Symbol '%s' doesn't present on host" % name)
      symbol_descr['is_compatible'] = False

  for type_name, type_descr in comparator.guest_types.items():
    type_descr['is_compatible'] = comparator.are_types_compatible(
        type_name, type_name)


def mark_incompatible_api(guest_api, host_api, verbose=False):
  """Marks compatibility of guest_api symbols and types against host_api.

  Both APIs are dicts with 'types' and 'symbols' tables. Results are written
  into guest_api in place.
  """
  guest_types, host_types = guest_api['types'], host_api['types']
  mark_incompatible_api_with_comparator(
      APIComparator(guest_types, host_types, verbose),
      guest_api['symbols'],
      host_api['symbols'],
      verbose)


def _override_custom_type_properties(guest_api, custom_api):
  for custom_type, custom_descr in custom_api.get('types', {}).items():
    guest_api['types'][custom_type].update(custom_descr)


def _set_call_method_for_symbols(guest_api):
  for symbol, descr in guest_api['symbols'].items():
    type_name = descr['type']
    if guest_api['types'][type_name]['kind'] == 'function':
      descr['call_method'] = 'default'
    else:
      descr['call_method'] = 'do_not_call'


def _override_custom_symbol_properties(guest_api, custom_api):
  custom_config = custom_api.get('config', {})

  if custom_config.get('ignore_variables', False):
    for symbol, descr in guest_api['symbols'].items():
      type_name = descr['type']
      if guest_api['types'][type_name]['kind'] != 'function':
        descr['call_method'] = 'ignore'

  if custom_config.get('force_incompatible', False):
    for symbol, descr in guest_api['symbols'].items():
      descr['is_compatible'] = False

  for custom_symbol, custom_descr in custom_api['symbols'].items():
    # Some exported symbols may not present in headers
    # which are used for guest_api generation.
    if custom_symbol not in guest_api['symbols']:
      if not custom_config.get('ignore_non_present', False):
        guest_api['symbols'][custom_symbol] = custom_descr
    else:
      # This may override 'call_method' for function-type symbol.
      # But should not override 'is_compatible', which is only used
      # when symbol isn't present in guest_api.
      assert 'is_compatible' not in custom_descr, ('The symbol %s is already '
                                                   'compatible: remove the '
                                                   'override') % custom_symbol
      if 'is_custom_compatible' in custom_descr:
        custom_descr['is_compatible'] = custom_descr['is_custom_compatible']
      guest_api['symbols'][custom_symbol].update(custom_descr)

  if custom_config.get('ignore_non_custom', False):
    for symbol, descr in guest_api['symbols'].items():
      if symbol not in custom_api['symbols']:
        descr['call_method'] = 'ignore'


def _check_force_compatibility_was_useful(types):
  for atype, descr in types.items():
    if ('force_compatible' in descr) or ('force_compatible_with' in descr):
      if not descr.get('useful_force_compatible', False):
        warnings.warn("Forcing compatibility for type '%s' is redundant" % (atype))


def _check_expected_types_compatibility(types):
  for atype, descr in types.items():
    if 'expect_compatible' in descr:
      if descr['is_compatible'] != descr['expect_compatible']:
        raise Exception(
            ("Compatibility expectation for type '%s' is wrong:"
             ' is_compatible=%s, expect_compatible=%s') %
            (atype, descr['is_compatible'], descr['expect_compatible']))


def mark_incompatible_and_custom_api(guest_api, host_api, custom_api, verbose=False):
  """Runs the compatibility analysis on guest_api and applies custom overrides.

  The steps run in this order:
    1. Merge custom type properties into guest_api, since the analysis
       reads them.
    2. Mark incompatible symbols and types, then warn about forced
       compatibility that had no effect.
    3. Assign default call methods, then apply custom symbol properties,
       which may override the analysis results.
    4. Check the 'expect_compatible' annotations on guest types.
  """
  # Step 1: type overrides must be in place before the analysis runs.
  _override_custom_type_properties(guest_api, custom_api)

  # Step 2: analysis plus a sanity check of forced compatibility.
  mark_incompatible_api(guest_api, host_api, verbose=verbose)
  _check_force_compatibility_was_useful(guest_api['types'])

  # Step 3: defaults first, so the custom symbol config can replace them.
  _set_call_method_for_symbols(guest_api)
  _override_custom_symbol_properties(guest_api, custom_api)

  # Step 4: expectations are checked against the final results.
  _check_expected_types_compatibility(guest_api['types'])
