# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for manipulating TensorFlow graph logic."""

from typing import Optional, Union

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import data_spec
from fcp.artifact_building import tensor_utils
from fcp.artifact_building import type_checks
from fcp.tensorflow import external_dataset
from tensorflow_federated.proto.v0 import computation_pb2

TfValue = Union[tf.Variable, tf.Tensor]
DatasetTensor = tf.Tensor
Argument = Union[TfValue, list[TfValue], DatasetTensor]
Args = Optional[Union[Argument, tuple[Argument, ...]]]

Result = Argument
MaybeSplitOutputs = Union[Result, tuple[Result, ...]]


EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX = 'example_selector'


def generate_example_selector_placeholders(
    type_spec: tff.Type,
    name_prefix: str,
):
  """Generates list of tff.compat.v1.placeholders for each leaf in a type spec.

  The order of the placeholders aligns with the order given by
  tff.structure.to_elements().

  Placeholders will be named by concatenating the name_prefix arg with the list
  of indexes at each level of the struct to get to the placeholder's leaf in the
  tff.Type.

  Args:
    type_spec: A type spec to infer the list of placeholders from. This is
      expected to be a tff.SequenceType or a tff.StructType, and if it is a
      tff.StructType, it is expected to be a tree of tff.StructTypes with
      tff.SequenceTypes at the leaves. This is expected to reflect the TFF type
      signature of the input client data.
    name_prefix: The name prefix that should be used when naming each
      placeholder.

  Returns:
    A list of tf.compat.v2.placeholders.
  """
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )
  if type_spec.is_sequence():
    # Each client input is a sequence of serialized `tf.Example`s, which is why
    # the leaves of these TFF type signatures are sequences. Each input sequence
    # of `tf.Example`s requires a single `ExampleSelector` that determines that
    # stream of `tf.Example`s is selected from the data store, which is why we
    # only have a single placeholder for the `ExampleSelector`.
    return [tf.compat.v1.placeholder(tf.string, shape=[], name=name_prefix)]
  else:
    type_spec.check_struct()
    type_spec_elements = tff.structure.to_elements(type_spec)
    placeholders = []
    for element_index, (_, element_type) in enumerate(type_spec_elements):
      placeholders.extend(
          generate_example_selector_placeholders(
              element_type, f'{name_prefix}_{element_index}'
          )
      )
    return placeholders


def embed_data_logic(
    client_data_type: tff.Type,
    dataspec: Optional[data_spec.NestedDataSpec] = None,
) -> tuple[tf.Tensor, list[MaybeSplitOutputs], list[tf.Tensor]]:
  """Embeds the data logic into the current TensorFlow graph.

  Adds dataset ops to the current graph, using the custom `ExternalDataset`
  which returns a placeholder token. The initialization op and data values are
  also returned.

  Args:
    client_data_type: The TFF type signature of the input client data.
    dataspec: If provided, either an instance of `data_spec.DataSpec` or a
      nested structure of these that matches the structure of the first element
      of the input to the client work part of the computation.

  Returns:
    A `tuple` containing the following (in order):
      token_placeholder: A dataset token placeholder tensor
      data_values: A list of dataset output values
      example_selector_placeholders: A possibly empty list of placeholders used
        for passing in example selector information into the client graph. This
        list will be empty iff dataspec is supplied.

  Raises:
    ValueError: If the number of dataset output from one data source is not 1.
    ValueError: If a node exists in the graph already that contains a node with
      the same name as the example selector placeholders.
  """
  data_values = []
  # Embeds the token placeholder for the custom ExternalDataset op.
  token_placeholder = tf.compat.v1.placeholder(
      tf.string, shape=[], name='data_token'
  )

  example_selector_placeholders = []
  if dataspec is None:
    example_selector_placeholders = generate_example_selector_placeholders(
        client_data_type, EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX
    )
    # If the first placeholder does not have the expected prefix, then it is due
    # to other variables in the graph, likely created from the input
    # tff.Computation, having the special name. This check ensures that no other
    # variables use this special example selector placeholder name and we can
    # easily extract example selector placeholders in the generated artifact.
    if example_selector_placeholders and (
        not (
            example_selector_placeholders[0].name.startswith(
                f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}:'
            )
            or example_selector_placeholders[0].name.startswith(
                f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}_0'
            )
        )
    ):
      raise ValueError(
          'Graph already contains a placeholder with name '
          f'{EXAMPLE_SELECTOR_PLACEHOLDER_PREFIX}. Please '
          'avoid the use of this special name.'
      )
    data_sources = make_data_sources_without_dataspec(client_data_type)
    assert len(example_selector_placeholders) == len(data_sources)
  else:
    data_sources = make_data_sources_with_dataspec(client_data_type, dataspec)

  # Embeds data source computations into the current graph.
  for index, data_comp in enumerate(data_sources):
    data_comp_import_args = [token_placeholder]
    if example_selector_placeholders:
      data_comp_import_args.append(example_selector_placeholders[index])
    ds_values = import_tensorflow(
        'data_{}'.format(index), data_comp, data_comp_import_args
    )  # pytype: disable=wrong-arg-types
    if len(ds_values) != 1:
      raise ValueError(
          'Expected one dataset output from a data source, found {}.'.format(
              str(len(ds_values))
          )
      )
    data_values.extend(ds_values)

  return token_placeholder, data_values, example_selector_placeholders


def import_tensorflow(
    name: str,
    comp: tff.framework.ConcreteComputation,
    args: Args = None,
    split_outputs: bool = False,
    session_token_tensor: Optional[tf.Tensor] = None,
) -> MaybeSplitOutputs:
  """Imports a tensorflow computation into the current graph.

  Args:
    name: The string name to use as the graph import prefix.
    comp: An instance of `tff.framework.ConcreteComputation` with just the
      `tensorflow` section.
    args: Either a single argument, a tuple of arguments, or None. An argument
      must be either: - a Python `list` containing either tensors or variables,
      or - a single variant tensor representing a dataset input.
    split_outputs: Whether to unpack the result tuple into a Python tuple. If
      `True`, `import_tensorflow` will return a tuple with multiple result
      objects, corresponding to the return elements in the type signature of
      `comp`. Notice that the return type signature of `comp` must be a tuple in
      this case. If `False`, `import_tensorflow` will return the entire result
      in a flattened form as a single Python result object. Each Python result
      object, similar to the argumens in `args`, will be either a Python `list`
      of variant tensors or a singleton Python list containing only the dataset
      variant tensor.
    session_token_tensor: A tensor in the current graph containing the "session
      token" of the TensorFlow being imported. This is useful for passing a
      session-global identifier into the graph for use with ops like
      `ServeSlices` and `ExternalDataset` that take in a token which references
      session-global state.

  Returns:
    One of:
      - A single result (Python `list` of variable value or variant tensors) if
        `split_outputs` is `False`.
      - A Python `tuple` of such results, if `split_outputs` is `True`.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  type_checks.check_type(name, str, name='name')
  type_checks.check_type(comp, tff.framework.ConcreteComputation, name='comp')
  type_checks.check_type(split_outputs, bool, name='split_outputs')

  comp_proto = tff.framework.ConcreteComputation.get_proto(comp)
  type_checks.check_type(
      comp_proto, computation_pb2.Computation, name='comp_proto'
  )

  which_comp = comp_proto.WhichOneof('computation')
  if which_comp != 'tensorflow':
    raise TypeError(
        'Expected a TensorFlow computation, found {}.'.format(which_comp)
    )
  if args is None:
    input_map = None
  elif isinstance(args, tuple):
    which_binding = comp_proto.tensorflow.parameter.WhichOneof('binding')
    if which_binding != 'struct':
      raise TypeError(
          'Expected a struct binding with a struct of args, found {}.'.format(
              which_binding
          )
      )
    input_map = {}
    for index, arg in enumerate(args):
      input_map.update(
          create_tensor_map(
              comp_proto.tensorflow.parameter.struct.element[index], arg
          )
      )
  else:
    input_map = create_tensor_map(comp_proto.tensorflow.parameter, args)
  if input_map is not None:
    # Add remappings for all potential control dependencies in the graph as
    # well. Since `tf.graph_util.import_graph_def` input map works on the tensor
    # (not graph node) level, we must handle this case also.
    def control_dep_name(name: str) -> str:
      if name.startswith('^'):
        return name
      node_name = name.split(':', maxsplit=1)[0]
      return f'^{node_name}'

    input_map.update(
        {
            control_dep_name(k): control_dep_name(v.name)
            for k, v in input_map.items()
            if not k.startswith('^')
        }
    )
  input_map = {} if input_map is None else input_map
  if (
      session_token_tensor is not None
      and comp_proto.tensorflow.session_token_tensor_name
  ):
    input_map[comp_proto.tensorflow.session_token_tensor_name] = (
        session_token_tensor
    )
  if split_outputs:
    return_elements = []
    subset_sizes = []
    which_binding = comp_proto.tensorflow.result.WhichOneof('binding')
    if which_binding != 'struct':
      raise TypeError(
          'If `split_outputs` is `True`, the result of the computation we are '
          'importing must be a `struct`; found {}.'.format(which_binding)
      )
    for binding in comp_proto.tensorflow.result.struct.element:
      tensor_names = _list_tensor_names_in_binding(binding)
      return_elements.extend(tensor_names)
      subset_sizes.append(len(tensor_names))
  else:
    return_elements = _list_tensor_names_in_binding(
        comp_proto.tensorflow.result
    )
    subset_sizes = [len(return_elements)]

  graph_def = tensor_utils.import_graph_def_from_any(
      comp_proto.tensorflow.graph_def
  )

  # We will be importing multiple GraphDefs into the server or client graphs.
  # These individual graphs may have identifical `shared_name` attributes on
  # variable ops, which causes the runtime to reference the same resource, which
  # is highly undesired. We must uniquify the names before importing.
  def uniquify_shared_names(
      graph_def: tf.compat.v1.GraphDef, suffix: bytes
  ) -> tf.compat.v1.GraphDef:
    for x in graph_def.node:
      shared_name = x.attr.get('shared_name')
      if shared_name is not None:
        if not shared_name.s:
          # Encountered an empty string shared name, avoid creating a shared
          # name that starts with an underscore (not allowed by TF).
          shared_name.s = b'None'
        shared_name.s += b'_' + suffix
    return graph_def

  uniquified_graph_def = uniquify_shared_names(
      graph_def, suffix=name.encode('utf-8')
  )
  if comp_proto.tensorflow.initialize_op:
    uniquified_graph_def = add_control_deps_for_init_op(
        uniquified_graph_def, comp_proto.tensorflow.initialize_op
    )
  import_result = tf.graph_util.import_graph_def(
      uniquified_graph_def,
      input_map=input_map,
      return_elements=return_elements,
      name=name,
  )

  if split_outputs:
    subsets = []
    offset = 0
    for subset_size in subset_sizes:
      next_offset = offset + subset_size
      subsets.append(import_result[offset:next_offset])
      offset = next_offset
    results = tuple(subsets)
  else:
    results = import_result[: subset_sizes[0]]
  return results


def _get_deps_for_graph_node(
    graph_def: tf.compat.v1.GraphDef, node_name: str
) -> set[str]:
  """Returns the set of node names that a node named `node_name` depends on.

  Note that this function does not work for nodes in the function library.

  Args:
    graph_def: The input graph, an instance of `tf.compat.v1.GraphDef`.
    node_name: The node name, a string.

  Returns:
    An instance of `set()` containing string names of the nodes `node_name`
    depends on in graph_def.

  Raises:
    TypeError: If either argument is of the wrong type.
  """
  type_checks.check_type(graph_def, tf.compat.v1.GraphDef, name='graph_def')
  type_checks.check_type(node_name, str, name='node_name')
  input_map = {}
  for node in graph_def.node:
    input_map[node.name] = set(tensor_utils.bare_name(x) for x in node.input)
  dependencies = set()
  initial_singleton = set([node_name])
  nodes_to_process = initial_singleton
  while nodes_to_process:
    dependencies.update(nodes_to_process)
    nodes_to_process = set.union(
        *[input_map[name] for name in nodes_to_process]
    ).difference(dependencies)
  return dependencies.difference(initial_singleton)


def add_control_deps_for_init_op(
    graph_def: tf.compat.v1.GraphDef, init_op: str
) -> tf.compat.v1.GraphDef:
  """Adds control deps on `init_op` to nodes in GraphDef.

  Note that control deps are not added to any of the ancestors of `init_op`
  (which would result in a control dep cycle) and control deps are not added to
  any nodes in the function library of a GraphDef.

  Args:
    graph_def: The input graph, an instance of `tf.compat.v1.GraphDef`.
    init_op: The init op name, a string.

  Returns:
    The updated graph, an instance of `tf.compat.v1.GraphDef`.

  Raises:
    TypeError: If either argument is of the wrong type.
  """
  type_checks.check_type(graph_def, tf.compat.v1.GraphDef, name='graph_def')
  type_checks.check_type(init_op, str, name='init_op')
  init_op_str = tensor_utils.bare_name(init_op)
  init_op_control_dep = '^{}'.format(init_op_str)
  deps = _get_deps_for_graph_node(graph_def, init_op_str).union(
      set([init_op_str])
  )
  new_graph_def = tf.compat.v1.GraphDef()
  new_graph_def.CopyFrom(graph_def)
  for new_node in new_graph_def.node:
    if new_node.name not in deps:
      node_inputs = new_node.input
      if init_op_control_dep not in node_inputs:
        new_node.input.append(init_op_control_dep)
  return new_graph_def


def create_tensor_map(
    binding: computation_pb2.TensorFlow.Binding,
    arg: list[Union[tf.Tensor, tf.Variable]],
) -> dict[str, tf.Tensor]:
  """Creates a `dict` mapping tensor names in the binding to tensors in `arg`.

  Args:
    binding: An instance of `computation_pb2.TensorFlow.Binding`.
    arg: Either a singleton Python `list` with variant tensor in case of a
      sequence binding, or a Python `list` of tensors or resource variables
      otherwise for a tuple binding.

  Returns:
    An instance of Python `dict` with the map as specified above.

  Raises:
    TypeError: If the argument types are incorrect.
    ValueError: If the arguments are malformed (e.g., multiple variant tensors).
  """
  type_checks.check_type(
      binding, computation_pb2.TensorFlow.Binding, name='binding'
  )
  type_checks.check_type(arg, list, name='arg')
  tensor_names_in_binding = _list_tensor_names_in_binding(binding)
  which_binding = binding.WhichOneof('binding')
  if which_binding == 'sequence':
    if (len(tensor_names_in_binding) != 1) or (len(arg) != 1):
      raise ValueError('Multiple variant tensors found.')
    variant_tensor_name = tensor_names_in_binding[0]
    arg = arg[0]
    if not tf.is_tensor(arg):
      raise TypeError('Expected a tensor, found {!r}.'.format(type(arg)))
    if arg.dtype != tf.variant:
      raise TypeError('Expected `tf.variant`, found {!r}.'.format(arg.dtype))
    return {variant_tensor_name: arg}
  else:
    return {
        k: v.read_value() if hasattr(v, 'read_value') else v
        for k, v in zip(tensor_names_in_binding, arg)
    }


def _validate_data_comp(data_comp: tff.Computation, type_spec: tff.Type):
  type_checks.check_type(data_comp.type_signature, tff.FunctionType)
  if not type_spec.is_assignable_from(data_comp.type_signature.result):
    type_mismatch_string = tff.types.type_mismatch_error_message(
        type_spec,
        data_comp.type_signature.result,
        tff.types.TypeRelation.ASSIGNABLE,
    )
    raise TypeError(
        'The data source constructed with the supplied dataspec returns data '
        'which does not match type of request. Details of the mismatch:\n'
        + type_mismatch_string
    )


def make_data_sources_with_dataspec(
    type_spec: tff.Type, ds: data_spec.NestedDataSpec
) -> list[tff.Computation]:
  """Creates a list of computations that feed data into the graph using specified example selectors.

  The computations use the custom ExternalDataset op to feed in example data.
  The computations will expect one input:
    -- A token specifying where the data store is on the device.
  Example selectors that describes what data to take from the on-device data
  store will be hard-coded into the computations.

  Args:
    type_spec: The TFF type signature of the output, which must be either a
      sequence, or a named tuple of sequences.
    ds: Either a single `data_spec.DataSpec`, or a nested structure of these,
      made up of Python containers, that exactly matches the structure of the
      `type_spec`.

  Returns:
    A list of `tff.Computation`s, each of which accepts a single `string`-typed
    tensor as input (the token for the ExternalDataset op) and returns a
    sequence as output (with the result that matches the corresponding part of
    `type_spec`). The computations appear on the list in a depth-first order
    (matching exactly the convention used in the
    `_list_tensor_names_in_binding()` method below).

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  assert ds
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )
  if type_spec.is_sequence():
    type_checks.check_type(ds, data_spec.DataSpec)
    assert isinstance(ds, data_spec.DataSpec)
    assert ds.example_selector_proto is not None
    sel_bytes = ds.example_selector_proto.SerializeToString()

    @tff.tf_computation(tf.string)
    def data_comp(token):
      """The data source computation.

      Args:
        token: The token placeholder tensor (`tf.string`).

      Returns:
        An instance of `tf.data.Dataset`.
      """
      if ds.preprocessing_fn is not None:
        processed_ds = ds.preprocessing_fn(
            external_dataset.ExternalDataset(token=token, selector=sel_bytes)
        )
      else:
        processed_ds = external_dataset.ExternalDataset(
            token=token, selector=sel_bytes
        )

      if 'Dataset' not in type(processed_ds).__name__:
        raise TypeError(
            'The preprocessing function returned an unrecognized non-dataset '
            'type {!r}.'.format(type(processed_ds))
        )
      return processed_ds

    _validate_data_comp(data_comp, type_spec)
    return [data_comp]
  else:
    type_spec.check_struct()
    if isinstance(ds, data_spec.DataSpec):
      raise TypeError(
          'Expected nested structure of `DataSpec`s conforming to '
          f'the structure of the type {type_spec}. '
          'Found single `DataSpec` instead.'
      )
    ds = tff.structure.from_container(ds)
    assert isinstance(ds, tff.structure.Struct)
    type_spec_elements = tff.structure.to_elements(type_spec)
    data_spec_elements = tff.structure.to_elements(ds)
    type_spec_element_names = [str(k) for k, _ in type_spec_elements]
    data_spec_element_names = [str(k) for k, _ in data_spec_elements]
    if type_spec_element_names != data_spec_element_names:
      raise TypeError(
          'Type vs. data spec elements names mismatch: {} vs. {}.'.format(
              str(type_spec_element_names), str(data_spec_element_names)
          )
      )
    elements = []
    for element_index, (_, element_type) in enumerate(type_spec_elements):
      elements.extend(
          make_data_sources_with_dataspec(element_type, ds[element_index])
      )
    return elements


def make_data_sources_without_dataspec(type_spec) -> list[tff.Computation]:
  """Creates a list of computations that feed data into the graph.

  The computations use the custom ExternalDataset op to feed in example data.
  The computations will expect two inputs:
    -- A token specifying where the data store is on the device.
    -- An example selector that describes what data to take from the on-device
      data store.

  Args:
    type_spec: The TFF type signature of the output, which must be either a
      sequence, or a named tuple of sequences.

  Returns:
    A list of `tff.Computation`s, each of which accepts a single `string`-typed
    tensor as input (the token for the ExternalDataset op) and returns a
    sequence as output (with the result that matches the corresponding part of
    `type_spec`). The computations appear on the list in a depth-first order
    (matching exactly the convention used in the
    `_list_tensor_names_in_binding()` method below).

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  type_spec = tff.to_type(type_spec)
  type_checks.check_type(
      type_spec, (tff.SequenceType, tff.StructType), name='type_spec'
  )
  if type_spec.is_sequence():

    @tff.tf_computation(tf.string, tf.string)
    def data_comp(token, example_selector):
      """The data source computation.

      Args:
        token: The token placeholder tensor (`tf.string`).
        example_selector: The example selector placeholder tensor (`tf.string`).

      Returns:
        An instance of `tf.data.Dataset`.
      """
      processed_ds = external_dataset.ExternalDataset(
          token=token, selector=example_selector
      )

      if 'Dataset' not in type(processed_ds).__name__:
        raise TypeError(
            'The preprocessing function returned an unrecognized non-dataset '
            'type {!r}.'.format(type(processed_ds))
        )
      return processed_ds

    _validate_data_comp(data_comp, type_spec)
    return [data_comp]
  else:  # type_spec is a struct.
    type_spec.check_struct()
    type_spec_elements = tff.structure.to_elements(type_spec)
    elements = []
    for _, element_type in type_spec_elements:
      elements.extend(make_data_sources_without_dataspec(element_type))
    return elements


def _list_tensor_names_in_binding(
    binding: computation_pb2.TensorFlow.Binding,
) -> list[str]:
  """Returns a flat Python list of tensor names that appear in the `binding`.

  Args:
    binding: An instance of `computation_pb2.TensorFlow.Binding` in which any
      sequence bindings must contain variant tensors.

  Returns:
    A list of `str` instances with tensor names that appear in `binding` in the
    order in which they appear in the depth-first traversal of the potentially
    nested binding structure.

  Raises:
    TypeError: If the arguments are of the wrong types.
  """
  type_checks.check_type(binding, computation_pb2.TensorFlow.Binding)
  which_binding = binding.WhichOneof('binding')
  if which_binding == 'tensor':
    return [str(binding.tensor.tensor_name)]
  elif which_binding == 'struct':
    result = []
    for element in binding.struct.element:
      result.extend(_list_tensor_names_in_binding(element))
    return result
  elif which_binding == 'sequence':
    which_sequence = binding.sequence.WhichOneof('binding')
    if which_sequence != 'variant_tensor_name':
      raise TypeError(
          'Expected a variant tensor in sequence binding, found {}.'.format(
              which_sequence
          )
      )
    return [binding.sequence.variant_tensor_name]
  else:
    raise TypeError('Unexpected type of binding {}.'.format(which_binding))
