# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either expresus or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MaterializableValueReference that reads from a TensorFlow checkpoint."""

from typing import Any, Optional
import uuid

import tensorflow as tf
import tensorflow_federated as tff


class CheckpointTensorReference(tff.program.MaterializableValueReference):
  """A reference to a tensor in a TF checkpoint file."""

  def __init__(self, tensor_name: str, dtype: tf.DType, shape: Any,
               checkpoint_future: tff.async_utils.SharedAwaitable):
    """Constructs a new CheckpointTensorReference object.

    Args:
      tensor_name: The name of the tensor in the TF checkpoint.
      dtype: The type of the tensor.
      shape: The shape of the tensor, expressed as a value convertible to
        `tf.TensorShape`.
      checkpoint_future: A `tff.async_utils.SharedAwaitable` that resolves to
        the TF checkpoint bytes once they're available.
    """
    self._tensor_name = tensor_name
    self._type_signature = tff.TensorType(dtype, shape)
    self._checkpoint_future = checkpoint_future
    self._tensor: Optional[tf.Tensor] = None

  @property
  def type_signature(self) -> tff.Type:
    return self._type_signature

  async def get_value(self) -> tff.program.MaterializedValue:
    if self._tensor is None:
      checkpoint = await self._checkpoint_future
      # Write to a file in TensorFlow's RamFileSystem to avoid disk I/O.
      tmpfile = f'ram://{uuid.uuid4()}.ckpt'
      with tf.io.gfile.GFile(tmpfile, 'wb') as f:
        f.write(checkpoint)
      try:
        self._tensor = tf.raw_ops.RestoreV2(
            prefix=tmpfile,
            tensor_names=[self._tensor_name],
            shape_and_slices=[''],
            dtypes=[self._type_signature.dtype])[0]
      finally:
        tf.io.gfile.remove(tmpfile)

    try:
      return self._tensor.numpy()
    except AttributeError as e:
      raise ValueError('get_value is only supported in eager mode.') from e
