# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper methods for proto creation logic."""

from typing import Optional

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import tensor_utils
from fcp.artifact_building import type_checks
from fcp.protos import plan_pb2


def make_tensor_spec_from_tensor(
    t: tf.Tensor, shape_hint: Optional[tf.TensorShape] = None
) -> tf.TensorSpec:
  """Creates a `TensorSpec` from Tensor w/ optional shape hint.

  Args:
    t: A `tf.Tensor` instance to be used to create a `TensorSpec`.
    shape_hint: A `tf.TensorShape` that provides a fully defined shape in the
      case that `t` is partially defined. If `t` has a fully defined shape,
      `shape_hint` is ignored. `shape_hint` must be compatible with the
      partially defined shape of `t`.

  Returns:
    A `tf.TensorSpec` instance corresponding to the input `tf.Tensor`.

  Raises:
    NotImplementedError: If the input `tf.Tensor` type is not supported.
    TypeError: if `shape_hint` is not `None` and is incompatible with the
      runtime shape of `t`.
  """
  if not tf.is_tensor(t):
    raise NotImplementedError(
        'Cannot handle type {t}: {v}'.format(t=type(t), v=t)
    )
  derived_shape = tf.TensorShape(t.shape)
  if not derived_shape.is_fully_defined() and shape_hint is not None:
    if derived_shape.is_compatible_with(shape_hint):
      shape = shape_hint
    else:
      raise TypeError(
          'shape_hint is not compatible with tensor ('
          f'{shape_hint} vs {derived_shape})'
      )
  else:
    shape = derived_shape
  return tf.TensorSpec(shape, t.dtype, name=t.name)


def make_measurement(
    t: tf.Tensor, name: str, tff_type: tff.types.TensorType
) -> plan_pb2.Measurement:
  """Creates a `plan_pb.Measurement` descriptor for a tensor.

  Args:
    t: A tensor to create the measurement for.
    name: The name of the measurement (e.g. 'server/loss').
    tff_type: The `tff.Type` of the measurement.

  Returns:
    An instance of `plan_pb.Measurement`.

  Raises:
    ValueError: If the `dtype`s or `shape`s of the provided tensor and TFF type
      do not match.
  """
  type_checks.check_type(tff_type, tff.types.TensorType)
  if tff_type.dtype != t.dtype:
    raise ValueError(
        f'`tff_type.dtype`: {tff_type.dtype} does not match '
        f"provided tensor's dtype: {t.dtype}."
    )
  if tff_type.shape.is_fully_defined() and t.shape.is_fully_defined():
    if tff_type.shape.as_list() != t.shape.as_list():
      raise ValueError(
          f'`tff_type.shape`: {tff_type.shape} does not match '
          f"provided tensor's shape: {t.shape}."
      )
  return plan_pb2.Measurement(
      read_op_name=t.name,
      name=name,
      tff_type=tff.types.serialize_type(tff_type).SerializeToString(),
  )


def make_metric(v: tf.Variable, stat_name_prefix: str) -> plan_pb2.Metric:
  """Creates a `plan_pb.Metric` descriptor for a resource variable.

  The stat name is formed by stripping the leading `..../` prefix and any
  colon-based suffix.

  Args:
    v: A variable to create the metric descriptor for.
    stat_name_prefix: The prefix (string) to use in formulating a stat name,
      excluding the trailing slash `/` (added automatically).

  Returns:
    An instance of `plan_pb.Metric` for `v`.

  Raises:
    TypeError: If the arguments are of the wrong types.
    ValueError: If the arguments are malformed (e.g., no leading name prefix).
  """
  type_checks.check_type(stat_name_prefix, str, name='stat_name_prefix')
  if not hasattr(v, 'read_value'):
    raise TypeError('Expected a resource variable, found {!r}.'.format(type(v)))
  bare_name = tensor_utils.bare_name(v.name)
  if '/' not in bare_name:
    raise ValueError(
        'Expected a prefix in the name, found none in {}.'.format(bare_name)
    )
  stat_name = '{}/{}'.format(
      stat_name_prefix, bare_name[(bare_name.find('/') + 1) :]
  )
  return plan_pb2.Metric(variable_name=v.read_value().name, stat_name=stat_name)
