# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for checkpoint_utils."""

import collections
import os
import typing
from typing import Any

from absl.testing import absltest
from absl.testing import parameterized

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from google.protobuf import any_pb2
from fcp.artifact_building import checkpoint_utils
from fcp.protos import plan_pb2


class CheckpointUtilsTest(tf.test.TestCase, parameterized.TestCase):

  def _assert_variable_functionality(
      self, test_vars: list[tf.Variable], test_value_to_save: Any = 10
  ):
    self.assertIsInstance(test_vars, list)
    initializer = tf.compat.v1.global_variables_initializer()
    for test_variable in test_vars:
      with self.test_session() as session:
        session.run(initializer)
        session.run(test_variable.assign(test_value_to_save))
        self.assertEqual(session.run(test_variable), test_value_to_save)

  def test_create_server_checkpoint_vars_and_savepoint_succeeds_state_vars(
      self,
  ):
    with tf.Graph().as_default():
      state_vars, _, _, savepoint = (
          checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
              server_state_type=tff.to_type([('foo1', tf.int32)]),
              server_metrics_type=tff.to_type([('bar2', tf.int32)]),
              write_metrics_to_checkpoint=True,
          )
      )
      self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
      self._assert_variable_functionality(state_vars)

  def test_create_server_checkpoint_vars_and_savepoint_succeeds_metadata_vars(
      self,
  ):
    def additional_checkpoint_metadata_var_fn(
        state_vars, metrics_vars, write_metrics_to_checkpoint
    ):
      del state_vars, metrics_vars, write_metrics_to_checkpoint
      return [tf.Variable(initial_value=b'dog', name='metadata')]

    with tf.Graph().as_default():
      _, _, metadata_vars, savepoint = (
          checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
              server_state_type=tff.to_type([('foo3', tf.int32)]),
              server_metrics_type=tff.to_type([('bar1', tf.int32)]),
              additional_checkpoint_metadata_var_fn=(
                  additional_checkpoint_metadata_var_fn
              ),
              write_metrics_to_checkpoint=True,
          )
      )
      self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
      self._assert_variable_functionality(
          metadata_vars, test_value_to_save=b'cat'
      )

  def test_create_server_checkpoint_vars_and_savepoint_succeeds_metrics_vars(
      self,
  ):
    with tf.Graph().as_default():
      _, metrics_vars, _, savepoint = (
          checkpoint_utils.create_server_checkpoint_vars_and_savepoint(
              server_state_type=tff.to_type([('foo2', tf.int32)]),
              server_metrics_type=tff.to_type([('bar3', tf.int32)]),
              write_metrics_to_checkpoint=True,
          )
      )
      self.assertIsInstance(savepoint, plan_pb2.CheckpointOp)
      self._assert_variable_functionality(metrics_vars)

  def test_tff_type_to_dtype_list_as_expected(self):
    tff_type = tff.FederatedType(
        tff.StructType([('foo', tf.int32), ('bar', tf.string)]), tff.SERVER
    )
    expected_dtype_list = [tf.int32, tf.string]
    self.assertEqual(
        checkpoint_utils.tff_type_to_dtype_list(tff_type), expected_dtype_list
    )

  def test_tff_type_to_dtype_list_type_error(self):
    list_type = [tf.int32, tf.string]
    with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
      checkpoint_utils.tff_type_to_dtype_list(list_type)

  def test_tff_type_to_tensor_spec_list_as_expected(self):
    tff_type = tff.FederatedType(
        tff.StructType(
            [('foo', tf.int32), ('bar', tff.TensorType(tf.string, shape=[1]))]
        ),
        tff.SERVER,
    )
    expected_tensor_spec_list = [
        tf.TensorSpec([], tf.int32),
        tf.TensorSpec([1], tf.string),
    ]
    self.assertEqual(
        checkpoint_utils.tff_type_to_tensor_spec_list(tff_type),
        expected_tensor_spec_list,
    )

  def test_tff_type_to_tensor_spec_list_type_error(self):
    list_type = [tf.int32, tf.string]
    with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
      checkpoint_utils.tff_type_to_tensor_spec_list(list_type)

  def test_pack_tff_value_with_tensors_as_expected(self):
    tff_type = tff.StructType([('foo', tf.int32), ('bar', tf.string)])
    value_list = [
        tf.constant(1, dtype=tf.int32),
        tf.constant('bla', dtype=tf.string),
    ]
    expected_packed_structure = tff.structure.Struct([
        ('foo', tf.constant(1, dtype=tf.int32)),
        ('bar', tf.constant('bla', dtype=tf.string)),
    ])
    self.assertEqual(
        checkpoint_utils.pack_tff_value(tff_type, value_list),
        expected_packed_structure,
    )

  def test_pack_tff_value_with_federated_server_tensors_as_expected(self):
    # This test must create a type that has `StructType`s nested under the
    # `FederatedType` to cover testing that tff.structure.pack_sequence_as
    # package correctly descends through the entire type tree.
    tff_type = tff.to_type(
        collections.OrderedDict(
            foo=tff.FederatedType(tf.int32, tff.SERVER),
            # Some arbitrarily deep nesting to ensure full traversals.
            bar=tff.FederatedType([(), ([tf.int32], tf.int32)], tff.SERVER),
        )
    )
    value_list = [tf.constant(1), tf.constant(2), tf.constant(3)]
    expected_packed_structure = tff.structure.from_container(
        collections.OrderedDict(
            foo=tf.constant(1), bar=[(), ([tf.constant(2)], tf.constant(3))]
        ),
        recursive=True,
    )
    self.assertEqual(
        checkpoint_utils.pack_tff_value(tff_type, value_list),
        expected_packed_structure,
    )

  def test_pack_tff_value_with_unmatched_input_sizes(self):
    tff_type = tff.StructType([('foo', tf.int32), ('bar', tf.string)])
    value_list = [tf.constant(1, dtype=tf.int32)]
    with self.assertRaises(ValueError):
      checkpoint_utils.pack_tff_value(tff_type, value_list)

  def test_pack_tff_value_with_tff_type_error(self):
    @tff.federated_computation
    def fed_comp():
      return tff.federated_value(0, tff.SERVER)

    tff_function_type = fed_comp.type_signature
    value_list = [tf.constant(1, dtype=tf.int32)]
    with self.assertRaisesRegex(TypeError, 'to be an instance of type'):
      checkpoint_utils.pack_tff_value(tff_function_type, value_list)

  def test_variable_names_from_structure_with_tensor_and_no_name(self):
    names = checkpoint_utils.variable_names_from_structure(tf.constant(1.0))
    self.assertEqual(names, ['v'])

  def test_variable_names_from_structure_with_tensor(self):
    names = checkpoint_utils.variable_names_from_structure(
        tf.constant(1.0), 'test_name'
    )
    self.assertEqual(names, ['test_name'])

  def test_variable_names_from_structure_with_named_tuple_type_and_no_name(
      self,
  ):
    names = checkpoint_utils.variable_names_from_structure(
        tff.structure.Struct([
            ('a', tf.constant(1.0)),
            (
                'b',
                tff.structure.Struct(
                    [('c', tf.constant(True)), ('d', tf.constant(0.0))]
                ),
            ),
        ])
    )
    self.assertEqual(names, ['v/a', 'v/b/c', 'v/b/d'])

  def test_variable_names_from_structure_with_named_struct(self):
    names = checkpoint_utils.variable_names_from_structure(
        tff.structure.Struct([
            ('a', tf.constant(1.0)),
            (
                'b',
                tff.structure.Struct(
                    [('c', tf.constant(True)), ('d', tf.constant(0.0))]
                ),
            ),
        ]),
        'test_name',
    )
    self.assertEqual(names, ['test_name/a', 'test_name/b/c', 'test_name/b/d'])

  def test_variable_names_from_structure_with_named_tuple_type_no_name_field(
      self,
  ):
    names = checkpoint_utils.variable_names_from_structure(
        tff.structure.Struct([
            (None, tf.constant(1.0)),
            (
                'b',
                tff.structure.Struct(
                    [(None, tf.constant(False)), ('d', tf.constant(0.0))]
                ),
            ),
        ]),
        'test_name',
    )
    self.assertEqual(names, ['test_name/0', 'test_name/b/0', 'test_name/b/d'])

  def test_save_tf_tensor_to_checkpoint_as_expected(self):
    temp_dir = self.create_tempdir()
    output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')

    tensor = tf.constant([[1.0, 2.0], [3.0, 4.0]])

    checkpoint_utils.save_tff_structure_to_checkpoint(
        tensor, ['v'], output_checkpoint_path=output_checkpoint_path
    )

    reader = tf.compat.v1.train.NewCheckpointReader(output_checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    self.assertLen(var_to_shape_map, 1)
    self.assertIn('v', var_to_shape_map)
    np.testing.assert_almost_equal(
        [[1.0, 2.0], [3.0, 4.0]], reader.get_tensor('v')
    )

  def test_save_tff_struct_to_checkpoint_as_expected(self):
    temp_dir = self.create_tempdir()
    output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')

    struct = tff.structure.Struct([
        ('foo', tf.constant(1, dtype=tf.int32)),
        ('bar', tf.constant('bla', dtype=tf.string)),
    ])

    checkpoint_utils.save_tff_structure_to_checkpoint(
        struct,
        ordered_var_names=['v/foo', 'v/bar'],
        output_checkpoint_path=output_checkpoint_path,
    )

    reader = tf.compat.v1.train.NewCheckpointReader(output_checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    self.assertLen(var_to_shape_map, 2)
    self.assertIn('v/foo', var_to_shape_map)
    self.assertIn('v/bar', var_to_shape_map)
    self.assertEqual(1, reader.get_tensor('v/foo'))
    self.assertEqual(b'bla', reader.get_tensor('v/bar'))

  def test_save_tff_struct_to_checkpoint_fails_if_wrong_num_var_names(self):
    temp_dir = self.create_tempdir()
    output_checkpoint_path = os.path.join(temp_dir, 'output_checkpoint.ckpt')

    struct = tff.structure.Struct([
        ('foo', tf.constant(1, dtype=tf.int32)),
        ('bar', tf.constant('bla', dtype=tf.string)),
    ])

    with self.assertRaisesRegex(ValueError, 'does not match the number'):
      checkpoint_utils.save_tff_structure_to_checkpoint(
          struct,
          ordered_var_names=['v/foo'],
          output_checkpoint_path=output_checkpoint_path,
      )

  @parameterized.named_parameters(
      ('tf.tensor', tf.constant(1.0)),
      ('ndarray', np.asarray([1.0, 2.0, 3.0])),
      ('npnumber', np.float64(1.0)),
      ('int', 1),
      ('float', 1.0),
      ('str', 'test'),
      ('bytes', b'test'),
  )
  def test_is_allowed(self, structure):
    self.assertTrue(checkpoint_utils.is_structure_of_allowed_types(structure))

  @parameterized.named_parameters(
      ('function', lambda x: x),
      ('any_proto', any_pb2.Any()),
  )
  def test_is_not_allowed(self, structure):
    self.assertFalse(checkpoint_utils.is_structure_of_allowed_types(structure))


class CreateDeterministicSaverTest(tf.test.TestCase):

  def test_failure_unknown_type(self):
    with self.assertRaisesRegex(ValueError, 'Do not know how to make'):
      # Using a cast in case the test is being run with static type checking.
      checkpoint_utils.create_deterministic_saver(
          typing.cast(list[tf.Variable], 0)
      )

  def test_creates_saver_for_list(self):
    with tf.Graph().as_default() as g:
      saver = checkpoint_utils.create_deterministic_saver([
          tf.Variable(initial_value=1.0, name='z'),
          tf.Variable(initial_value=2.0, name='x'),
          tf.Variable(initial_value=3.0, name='y'),
      ])
    self.assertIsInstance(saver, tf.compat.v1.train.Saver)
    test_filepath = self.create_tempfile().full_path
    with tf.compat.v1.Session(graph=g) as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      saver.save(sess, save_path=test_filepath)
    variable_specs = tf.train.list_variables(test_filepath)
    self.assertEqual([('x', []), ('y', []), ('z', [])], variable_specs)

  def test_creates_saver_for_dict(self):
    with tf.Graph().as_default() as g:
      saver = checkpoint_utils.create_deterministic_saver({
          'foo': tf.Variable(initial_value=1.0, name='z'),
          'baz': tf.Variable(initial_value=2.0, name='x'),
          'bar': tf.Variable(initial_value=3.0, name='y'),
      })
    self.assertIsInstance(saver, tf.compat.v1.train.Saver)
    test_filepath = self.create_tempfile().full_path
    with tf.compat.v1.Session(graph=g) as sess:
      sess.run(tf.compat.v1.global_variables_initializer())
      saver.save(sess, save_path=test_filepath)
    variable_specs = tf.train.list_variables(test_filepath)
    self.assertEqual([('bar', []), ('baz', []), ('foo', [])], variable_specs)


if __name__ == '__main__':
  absltest.main()
