# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities related to plan protos."""

from typing import TypeVar

import tensorflow as tf
from fcp.artifact_building import tensor_utils
from fcp.protos import plan_pb2


_PlanT = TypeVar('_PlanT', plan_pb2.Plan, plan_pb2.ClientOnlyPlan)


# TODO(team): Remove in favor of save_from_checkpoint_op.
def write_checkpoint(
    sess: tf.compat.v1.Session,
    checkpoint_op: plan_pb2.CheckpointOp,
    checkpoint_filename: str,
) -> None:
  """Writes from a CheckpointOp, without executing before/after save ops."""
  if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
    raise ValueError('A CheckpointOp is required.')
  if (
      checkpoint_op
      and checkpoint_op.saver_def
      and checkpoint_op.saver_def.save_tensor_name
  ):
    sess.run(
        checkpoint_op.saver_def.save_tensor_name,
        {checkpoint_op.saver_def.filename_tensor_name: checkpoint_filename},
    )
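

# Example usage (a sketch, not part of this module; assumes `sess` is a
# tf.compat.v1.Session for the plan's graph and `checkpoint_op` is a
# plan_pb2.CheckpointOp taken from the plan):
#
#   write_checkpoint(sess, checkpoint_op, '/tmp/state.ckpt')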


# TODO(team): Remove in favor of restore_from_checkpoint_op.
def read_checkpoint(
    sess: tf.compat.v1.Session,
    checkpoint_op: plan_pb2.CheckpointOp,
    checkpoint_filename: str,
) -> None:
  """Reads from a CheckpointOp, without executing before/after restore ops."""
  if not isinstance(checkpoint_op, plan_pb2.CheckpointOp):
    raise ValueError('A CheckpointOp is required.')
  if (
      checkpoint_op
      and checkpoint_op.saver_def
      and checkpoint_op.saver_def.restore_op_name
  ):
    sess.run(
        checkpoint_op.saver_def.restore_op_name,
        {checkpoint_op.saver_def.filename_tensor_name: checkpoint_filename},
    )
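

# Example usage (sketch; restores, under the same assumptions as the
# write_checkpoint example above, the state that example saved):
#
#   read_checkpoint(sess, checkpoint_op, '/tmp/state.ckpt')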


def convert_graphdef_to_flatbuffer(
    graph: tf.compat.v1.GraphDef,
    spec: plan_pb2.TensorflowSpec,
    guarantee_all_funcs_one_use: bool = False,
) -> bytes:
  """Converts a GraphDef to a serialized TFLite model FlatBuffer."""

  def create_input(input_tensor):
    return (input_tensor.name, [item.size for item in input_tensor.shape.dim])

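  # The dataset token is passed as a scalar input (empty shape); the remaining
  # inputs come from the plan's declared input tensor specs.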
  inputs = [(spec.dataset_token_tensor_name, [])]
  for input_tensor in spec.input_tensor_specs:
    inputs.append(create_input(input_tensor))
  converter = tf.compat.v1.lite.TFLiteConverter(
      graph,
      input_tensors=None,
      output_tensors=None,
      input_arrays_with_shape=inputs,
      output_arrays=[item.name for item in spec.output_tensor_specs],
  )

  # pylint: disable=protected-access
  # Sets the control output node names. This is used when converting a tf.Graph
  # with no output tensors.
  converter._control_output_arrays = spec.target_node_names
  # Unfold large splat constants to reduce the size of the flatbuffer.
  converter._experimental_unfold_large_splat_constant = True
  # Exclude conversion metadata generation to reduce conversion time.
  converter.exclude_conversion_metadata = True
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,
      tf.lite.OpsSet.SELECT_TF_OPS,
  ]
  converter._experimental_allow_all_select_tf_ops = True
  converter._experimental_guarantee_all_funcs_one_use = (
      guarantee_all_funcs_one_use
  )
  # Instructs the TF Lite converter to not eliminate Assert ops, since the
  # client code needs this op to verify result correctness.
  converter._experimental_preserve_assert_op = True
  # pylint: enable=protected-access
  converter.experimental_enable_resource_variables = True
  return converter.convert()
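

# Example usage (a sketch; assumes `client_only_plan` is a deserialized
# plan_pb2.ClientOnlyPlan, mirroring how generate_and_add_flat_buffer_to_plan
# below invokes this function):
#
#   graph_def = tf.compat.v1.GraphDef.FromString(client_only_plan.graph)
#   tflite_model = convert_graphdef_to_flatbuffer(
#       graph_def, client_only_plan.phase.tensorflow_spec
#   )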


def generate_and_add_flat_buffer_to_plan(
    plan: _PlanT, forgive_tflite_conversion_failure: bool = True
) -> _PlanT:
  """Generates and adds a TFLite model to the specified Plan.

  Note: This method mutates the plan argument.

  Args:
    plan: An input plan_pb2.Plan or plan_pb2.ClientOnlyPlan object.
    forgive_tflite_conversion_failure: If True, no exception is raised when
      TFLite conversion fails; the plan is returned unmutated instead.

  Returns:
    The input Plan mutated to include a TFLite model when TFLite conversion
    succeeds, or the Plan without any mutation if TFLite conversion does not
    succeed.

  Raises:
    RuntimeError: if TFLite conversion fails and
      forgive_tflite_conversion_failure is set to False.
  """

  def convert(graph_def, tensorflow_spec, guarantee_all_funcs_one_use=False):
    stateful_partitioned_call_err = (
        "'tf.StatefulPartitionedCall' op is"
        ' neither a custom op nor a flex op'
    )
    # Convert the graph; the resulting flatbuffer bytes are stored on the plan
    # by the caller below.
    try:
      return convert_graphdef_to_flatbuffer(
          graph_def, tensorflow_spec, guarantee_all_funcs_one_use
      )
    except Exception as e:  # pylint: disable=broad-except
      # If the converter failed on a tf.StatefulPartitionedCall and we have
      # not yet tried guarantee_all_funcs_one_use, retry once with it enabled.
      if (
          stateful_partitioned_call_err in str(e)
          and not guarantee_all_funcs_one_use
      ):
        return convert(graph_def, tensorflow_spec, True)
      elif forgive_tflite_conversion_failure:
        return b''
      else:
        raise RuntimeError(
            f'Failure during TFLite conversion of the client graph: {str(e)}'
        ) from e

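  # Dispatch on the concrete plan type: a Plan stores its client graph as a
  # packed Any proto and only the first phase's client TensorflowSpec is used,
  # while a ClientOnlyPlan stores a serialized GraphDef and a single phase.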
  if isinstance(plan, plan_pb2.Plan):
    client_graph_def = tensor_utils.import_graph_def_from_any(
        plan.client_graph_bytes
    )
    plan.client_tflite_graph_bytes = convert(
        client_graph_def, plan.phase[0].client_phase.tensorflow_spec
    )
  elif isinstance(plan, plan_pb2.ClientOnlyPlan):
    client_graph_def = tf.compat.v1.GraphDef.FromString(plan.graph)
    plan.tflite_graph = convert(client_graph_def, plan.phase.tensorflow_spec)
  else:
    raise NotImplementedError(f'Unsupported _PlanT {type(plan)}')
  return plan
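

# Example usage (a sketch; assumes `plan_bytes` holds a serialized
# plan_pb2.ClientOnlyPlan):
#
#   plan = plan_pb2.ClientOnlyPlan.FromString(plan_bytes)
#   plan = generate_and_add_flat_buffer_to_plan(
#       plan, forgive_tflite_conversion_failure=False
#   )
#   # plan.tflite_graph is now populated; a conversion failure would have
#   # raised RuntimeError instead.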
