# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A class to specify on-device dataset inputs."""

from collections.abc import Callable
from typing import Any, Optional, Union

import tensorflow as tf
import tensorflow_federated as tff

from fcp.artifact_building import type_checks
from fcp.protos import plan_pb2


class DataSpec:
  """A specification of a single dataset input."""

  __slots__ = (
      '_example_selector_proto',
      '_preprocessing_fn',
      '_preprocessing_comp',
      '_fingerprint',
  )

  def __init__(
      self,
      example_selector_proto: plan_pb2.ExampleSelector,
      preprocessing_fn: Optional[
          Callable[[tf.data.Dataset], tf.data.Dataset]
      ] = None,
  ):
    """Constructs a specification of a dataset input.

    Args:
      example_selector_proto: An instance of `plan_pb2.ExampleSelector` proto.
      preprocessing_fn: A callable that accepts as an argument the raw input
        `tf.data.Dataset` with `string`-serialized items, performs any desired
        preprocessing such as deserialization, filtering, batching, and
        formatting, and returns the transformed `tf.data.Dataset` as a result.
        If preprocessing_fn is set to None, it is expected that any client data
        preprocessing has already been incorporated into the `tff.Computation`
        that this `DataSpec` is associated with.

    Raises:
      TypeError: If the types of the arguments are invalid.
    """
    type_checks.check_type(
        example_selector_proto,
        plan_pb2.ExampleSelector,
        name='example_selector_proto',
    )
    if preprocessing_fn is not None:
      type_checks.check_callable(preprocessing_fn, name='preprocessing_fn')
    self._example_selector_proto = example_selector_proto
    self._preprocessing_fn = preprocessing_fn
    # Set once self.preprocessing_comp is accessed, as we can't call
    # tff.computation in __init__.
    self._preprocessing_comp = None

  @property
  def example_selector_proto(self) -> plan_pb2.ExampleSelector:
    return self._example_selector_proto

  @property
  def preprocessing_fn(
      self,
  ) -> Optional[Callable[[tf.data.Dataset], tf.data.Dataset]]:
    return self._preprocessing_fn

  @property
  def preprocessing_comp(self) -> tff.Computation:
    """Returns the preprocessing computation for the input dataset."""
    if self._preprocessing_comp is None:
      if self.preprocessing_fn is None:
        raise ValueError(
            "DataSpec's preprocessing_fn is None so a "
            'preprocessing tff.Computation cannot be generated.'
        )
      self._preprocessing_comp = tff.tf_computation(
          self.preprocessing_fn, tff.SequenceType(tf.string)
      )
    return self._preprocessing_comp

  @property
  def type_signature(self) -> tff.Type:
    """Returns the type signature of the result of the preprocessing_comp.

    Effectively the type or 'spec' of the parsed example from the example store
    pointed at by `example_selector_proto`.
    """
    return self.preprocessing_comp.type_signature.result


def is_data_spec_or_structure(x: Any) -> bool:
  """Returns True iff `x` is either a `DataSpec` or a nested structure of it."""
  if x is None:
    return False
  if isinstance(x, DataSpec):
    return True
  try:
    x = tff.structure.from_container(x)
    return all(
        is_data_spec_or_structure(y) for _, y in tff.structure.to_elements(x)
    )
  except TypeError:
    return False


def check_data_spec_or_structure(x: Any, name: str):
  """Raises error iff `x` is not a `DataSpec` or a nested structure of it."""
  if not is_data_spec_or_structure(x):
    raise TypeError(
        f'Expected `{name}` to be a `DataSpec` or a nested '
        f'structure of it, found {str(x)}.'
    )


NestedDataSpec = Union[DataSpec, dict[str, 'NestedDataSpec']]


def generate_example_selector_bytes_list(ds: NestedDataSpec):
  """Returns an ordered list of the bytes of each DataSpec's example selector.

  The order aligns with the order of a struct given by
  tff.structure.to_elements().

  Args:
    ds: A `NestedDataSpec`.
  """
  if isinstance(ds, DataSpec):
    return [ds.example_selector_proto.SerializeToString()]
  else:
    ds = tff.structure.from_container(ds)
    assert isinstance(ds, tff.structure.Struct)
    data_spec_elements = tff.structure.to_elements(ds)
    selector_bytes_list = []
    for _, element in data_spec_elements:
      selector_bytes_list.extend(generate_example_selector_bytes_list(element))
    return selector_bytes_list
