// Copyright 2021 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto3"; package google.internal.federated.plan; import "google/protobuf/any.proto"; import "tensorflow/core/framework/tensor.proto"; import "tensorflow/core/framework/types.proto"; import "tensorflow/core/protobuf/saver.proto"; import "tensorflow/core/protobuf/struct.proto"; option java_package = "com.google.internal.federated.plan"; option java_multiple_files = true; option java_outer_classname = "PlanProto"; // Primitives // =========== // Represents an operation to save or restore from a checkpoint. Some // instances of this message may only be used either for restore or for // save, others for both directions. This is documented together with // their usage. // // This op has four essential uses: // 1. read and apply a checkpoint. // 2. write a checkpoint. // 3. read and apply from an aggregated side channel. // 4. write to a side channel (grouped with write a checkpoint). // We should consider splitting this into four separate messages. message CheckpointOp { // An optional standard saver def. If not provided, only the // op(s) below will be executed. This must be a version 1 SaverDef. tensorflow.SaverDef saver_def = 1; // An optional operation to run before the saver_def is executed for // restore. string before_restore_op = 2; // An optional operation to run after the saver_def has been // executed for restore. 
If side_channel_tensors are provided, then // they should be provided in a feed_dict to this op. string after_restore_op = 3; // An optional operation to run before the saver_def will be // executed for save. string before_save_op = 4; // An optional operation to run after the saver_def has been // executed for save. If there are side_channel_tensors, this op // should be run after the side_channel_tensors have been fetched. string after_save_op = 5; // In addition to being saved and restored from a checkpoint, one can // also save and restore via a side channel. The keys in this map are // the names of the tensors transmitted by the side channel. These (key) // tensors should be read off just before saving a SaveDef and used // by the code that handles the side channel. Any variables provided this // way should NOT be saved in the SaveDef. // // For restoring, the variables that are provided by the side channel // are restored differently than those for a checkpoint. For those from // the side channel, these should be restored by calling the before_restore_op // with a feed dict whose keys are the restore_names in the SideChannel and // whose values are the values to be restored. map side_channel_tensors = 6; // An optional name of a tensor in to which a unique token for the current // session should be written. // // This session identifier allows TensorFlow ops such as `ServeSlices` or // `ExternalDataset` to refer to callbacks and other session-global objects // registered before running the session. string session_token_tensor_name = 7; } message SideChannel { // A side channel whose variables are processed via SecureAggregation. // This side channel implements aggregation via sum over a set of // clients, so the restored tensor will be a sum of multiple clients // inputs into the side channel. Hence this will restore during the // read_aggregate_update restore, not the per-client read_update restore. 
message SecureAggregand { message Dimension { int64 size = 1; } // Dimensions of the aggregand. This is used by the secure aggregation // protocol in its early rounds, not as redundant info which could be // obtained by reading the dimensions of the tensor itself. repeated Dimension dimension = 3; // The data type anticipated by the server-side graph. tensorflow.DataType dtype = 4; // SecureAggregation will compute sum modulo this modulus. message FixedModulus { uint64 modulus = 1; } // SecureAggregation will for each shard compute sum modulo m with m at // least (1 + shard_size * (base_modulus - 1)), then aggregate // shard results with non-modular addition. Here, shard_size is the number // of clients in the shard. // // Note that the modulus for each shard will be greater than the largest // possible (non-modular) sum of the inputs to that shard. That is, // assuming each client has input on range [0, base_modulus), the result // will be identical to non-modular addition (i.e. federated_sum). // // While any m >= (1 + shard_size * (base_modulus - 1)), the current // implementation takes // m = 2**ceil(log_2(1 + shard_size * (base_modulus - 1))), which is the // smallest possible value of m that is also a power of 2. This choice is // made because (a) it uses the same number of bits per vector entry as // valid smaller m, using the current on-the-wire encoding scheme, and (b) // it enables the underlying mask-generation PRNG to run in its most // computationally efficient mode, which can be up to 2x faster. message ModulusTimesShardSize { uint64 base_modulus = 1; } oneof modulus_scheme { // Bitwidth of the aggregand. // // This is the bitwidth of an input value (i.e. the bitwidth that // quantization should target). The Secure Aggregation bitwidth (i.e., // the bitwidth of the *sum* of the input values) will be a function of // this bitwidth and the number of participating clients, as negotiated // with the server when the protocol is initiated. 
// // Deprecated; prefer fixed_modulus instead. int32 quantized_input_bitwidth = 2 [deprecated = true]; FixedModulus fixed_modulus = 5; ModulusTimesShardSize modulus_times_shard_size = 6; } reserved 1; } // What type of side channel is used. oneof type { SecureAggregand secure_aggregand = 1; } // When restoring the name of the tensor to restore to. This is the name // (key) supplied in the feed_dict in the before_restore_op in order to // restore the tensor provided by the side channel (which will be the // value in the feed_dict). string restore_name = 2; } // Container for a metric used by the internal toolkit. message Metric { // Name of an Op to run to read the value. string variable_name = 1; // A human-readable name for the statistic. Metric names are usually // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'. // Must be 7-bit ASCII and under 122 characters. string stat_name = 2; // The human-readable name of another metric by which this metric should be // normalized, if any. If empty, this Metric should be aggregated with simple // summation. If not empty, the Metric is aggregated according to // weighted_metric_sum = sum_i (metric_i * weight_i) // weight_sum = sum_i weight_i // average_metric_value = weighted_metric_sum / weight_sum string weight_name = 3; } // Controls the format of output metrics users receive. Represents instructions // for how metrics are to be output to users, controlling the end format of // the metric users receive. message OutputMetric { // Metric name. string name = 1; oneof value_source { // A metric representing one stat with aggregation type sum. SumOptions sum = 2; // A metric representing a ratio between metrics with aggregation // type average. AverageOptions average = 3; // A metric that is not aggregated by the MetricReportAggregator or // metrics_loader. This includes metrics like 'num_server_updates' that are // aggregated in TensorFlow. 
NoneOptions none = 4; // A metric representing one stat with aggregation type only sample. // Samples at most 101 clients' values. OnlySampleOptions only_sample = 5; } // Iff True, the metric will be plotted in the default view of the // task level Colab automatically. oneof visualization_info { bool auto_plot = 6 [deprecated = true]; VisualizationSpec plot_spec = 7; } } message VisualizationSpec { // Different allowable plot types. enum VisualizationType { NONE = 0; DEFAULT_PLOT_FOR_TASK_TYPE = 1; LINE_PLOT = 2; LINE_PLOT_WITH_PERCENTILES = 3; HISTOGRAM = 4; } // Defines the plot type to provide downstream. VisualizationType plot_type = 1; // The x-axis which to provide for the given metric. Must be the name of a // metric or counter. Recommended x_axis options are source_round, round, // or time. string x_axis = 2; // Iff True, metric will be displayed on a population level dashboard. bool plot_on_population_dashboard = 3; } // A metric representing one stat with aggregation type sum. message SumOptions { // Name for corresponding Metric stat_name field. string stat_name = 1; // Iff True, a cumulative sum over rounds will be provided in addition to a // sum per round for the value metric. bool include_cumulative_sum = 2; // Iff True, sample of at most 101 clients' values. // Used to calculate quantiles in downstream visualization pipeline. bool include_client_samples = 3; } // A metric representing a ratio between metrics with aggregation type average. // Represents: numerator stat / denominator stat. message AverageOptions { // Numerator stat name pointing to corresponding Metric stat_name. string numerator_stat_name = 1; // Denominator stat name pointing to corresponding Metric stat_name. string denominator_stat_name = 2; // Name for corresponding Metric stat_name that is the ratio of the // numerator stat / denominator stat. string average_stat_name = 3; // Iff True, sample of at most 101 client's values. 
// Used to calculate quantiles in downstream visualization pipeline. bool include_client_samples = 4; } // A metric representing one stat with aggregation type none. message NoneOptions { // Name for corresponding Metric stat_name field. string stat_name = 1; } // A metric representing one stat with aggregation type only sample. message OnlySampleOptions { // Name for corresponding Metric stat_name field. string stat_name = 1; } // Represents a data set. This is used for testing. message Dataset { // Represents the data set for one client. message ClientDataset { // A string identifying the client. string client_id = 1; // A list of serialized tf.Example protos. repeated bytes example = 2; // Represents a dataset whose examples are selected by an ExampleSelector. message SelectedExample { ExampleSelector selector = 1; repeated bytes example = 2; } // A list of (selector, dataset) pairs. Used in testing some *TFF-based // tasks* that require multiple datasets as client input, e.g., a TFF-based // personalization eval task requires each client to provide at least two // datasets: one for train, and the other for test. repeated SelectedExample selected_example = 3; } // A list of client data. repeated ClientDataset client_data = 1; } // Represents predicates over metrics - i.e., expectations. This is used in // training/eval tests to encode metric names and values expected to be reported // by a client execution. message MetricTestPredicates { // The value must lie in [lower_bound; upper_bound]. Can also be used for // approximate matching (lower == value - epsilon; upper = value + epsilon). message Interval { double lower_bound = 1; double upper_bound = 2; } // The value must be a real value as long as the value of the weight_name // metric is non-zero. If the weight metric is zero, then it is acceptable for // the value to be non-real. message RealIfNonzeroWeight { string weight_name = 1; } message MetricCriterion { // Name of the metric. 
string name = 1; // FL training round this metric is expected to appear in. int32 training_round_index = 2; // If none of the following is set, no matching is performed; but the // metric is still expected to be present (with whatever value). oneof Criterion { // The reported metric must be < lt. float lt = 3; // The reported metric must be > gt. float gt = 4; // The reported metric must be <= le. float le = 5; // The reported metric must be >= ge. float ge = 6; // The reported metric must be == eq. float eq = 7; // The reported metric must lie in the interval. Interval interval = 8; // The reported metric is not NaN or +/- infinity. bool real = 9; // The reported metric is real (i.e., not NaN or +/- infinity) if the // value of an associated weight is not 0. RealIfNonzeroWeight real_if_nonzero_weight = 10; } } repeated MetricCriterion metric_criterion = 1; reserved 2; } // Client Phase // ============ // A `TensorflowSpec` that is executed on the client in a single `tf.Session`. // In federated optimization, this will correspond to one `ServerPhase`. message ClientPhase { // A short CamelCase name for the ClientPhase. string name = 2; // Minimum number of clients in aggregation. // In secure aggregation mode this is used to configure the protocol instance // in a way that server can't learn aggregated values with number of // participants lower than this number. // Without secure aggregation server still respects this parameter, // ensuring that aggregated values never leave server RAM unless they include // data from (at least) specified number of participants. int32 minimum_number_of_participants = 3; // If populated, `io_router` must be specified. oneof spec { // A functional interface for the TensorFlow logic the client should // perform. TensorflowSpec tensorflow_spec = 4 [lazy = true]; // Spec for client plans that issue example queries and send the query // results directly to an aggregator with no or little additional // processing. 
ExampleQuerySpec example_query_spec = 9 [lazy = true]; } // The specification of the inputs coming either from customer apps // (Local Compute) or the federated protocol (Federated Compute). oneof io_router { FederatedComputeIORouter federated_compute = 5 [lazy = true]; LocalComputeIORouter local_compute = 6 [lazy = true]; FederatedComputeEligibilityIORouter federated_compute_eligibility = 7 [lazy = true]; FederatedExampleQueryIORouter federated_example_query = 8 [lazy = true]; } reserved 1; } // TensorflowSpec message describes a single call into TensorFlow, including the // expected input tensors that must be fed when making that call, which // output tensors to be fetched, and any operations that have no output but must // be run. The TensorFlow session will then use the input tensors to do some // computation, generally reading from one or more datasets, and provide some // outputs. // // Conceptually, client or server code uses this proto along with an IORouter // to build maps of names to input tensors, vectors of output tensor names, // and vectors of target nodes: // // CreateTensorflowArguments( // TensorflowSpec& spec, // IORouter& io_router, // const vector>* input_tensors, // const vector* output_tensor_names, // const vector* target_node_names); // // Where `input_tensor`, `output_tensor_names` and `target_node_names` // correspond to the arguments of TensorFlow C++ API for // `tensorflow::Session:Run()`, and the client executes only a single // invocation. // // Note: the execution engine never sees any concepts related to the federated // protocol, e.g. input checkpoints or aggregation protocols. This is a "tensors // in, tensors out" interface. New aggregation methods can be added without // having to modify the execution engine / TensorflowSpec message, instead they // should modify the IORouter messages. 
// // Note: both `input_tensor_specs` and `output_tensor_specs` are full // `tensorflow.TensorSpecProto` messages, though TensorFlow technically // only requires the names to feed the values into the session. The additional // dtypes/shape information must always be included in case the runtime // executing this TensorflowSpec wants to perform additional, optional static // assertions. The runtimes however are free to ignore the dtype/shapes and only // rely on the names if so desired. // // Assertions: // - all names in `input_tensor_specs`, `output_tensor_specs`, and // `target_node_names` must appear in the serialized GraphDef where // the TF execution will be invoked. // - `output_tensor_specs` or `target_node_names` must be non-empty, otherwise // there is nothing to execute in the graph. message TensorflowSpec { // The name of a tensor into which a unique token for the current session // should be written. The corresponding tensor is a scalar string tensor and // is separate from `input_tensors` as there is only one. // // A session token allows TensorFlow ops such as `ServeSlices` or // `ExternalDataset` to refer to callbacks and other session-global objects // registered before running the session. In the `ExternalDataset` case, a // single dataset_token is valid for multiple `tf.data.Dataset` objects as // the token can be thought of as a handle to a dataset factory. string dataset_token_tensor_name = 1; // TensorSpecs of inputs which will be passed to TF. // // Corresponds to the `feed_dict` parameter of `tf.Session.run()` in // TensorFlow's Python API, excluding the dataset_token listed above. // // Assertions: // - All the tensor names designated as inputs in the corresponding IORouter // must be listed (otherwise the IORouter input work is unused). // - All placeholders in the TF graph must be listed here, with the // exception of the dataset_token which is explicitly set above (otherwise // TensorFlow will fail to execute). 
repeated tensorflow.TensorSpecProto input_tensor_specs = 2; // TensorSpecs that should be fetched from TF after execution. // // Corresponds to the `fetches` parameter of `tf.Session.run()` in // TensorFlow's Python API, and the `output_tensor_names` in TensorFlow's C++ // API. // // Assertions: // - The set of tensor names here must strictly match the tensor names // designated as outputs in the corresponding IORouter (if any exist). repeated tensorflow.TensorSpecProto output_tensor_specs = 3; // Node names in the graph that should be executed, but the output not // returned. // // Corresponds to the `fetches` parameter of `tf.Session.run()` in // TensorFlow's Python API, and the `target_node_names` in TensorFlow's C++ // API. // // This is intended for use with operations that do not produce tensors, but // nonetheless are required to run (e.g. serializing checkpoints). repeated string target_node_names = 4; // Map of Tensor names to constant inputs. // Note: tensors specified via this message should not be included in // input_tensor_specs. map constant_inputs = 5; // The fields below are added by OnDevicePersonalization module. // Specifies an example selection procedure. ExampleSelector example_selector = 999; } // ExampleQuerySpec message describes client execution that issues example // queries and sends the query results directly to an aggregator with no or // little additional processing. // This message describes one or more example store queries that perform the // client side analytics computation in C++. The corresponding output vectors // will be converted into the expected federated protocol output format. // This must be used in conjunction with the `FederatedExampleQueryIORouter`. message ExampleQuerySpec { message OutputVectorSpec { // The output vector name. string vector_name = 1; // Supported data types for the vector of information. 
enum DataType { UNSPECIFIED = 0; INT32 = 1; INT64 = 2; BOOL = 3; FLOAT = 4; DOUBLE = 5; BYTES = 6; STRING = 7; } // The data type for each entry in the vector. DataType data_type = 2; } message ExampleQuery { // The `ExampleSelector` to issue the query with. ExampleSelector example_selector = 1; // Indicates that the query returns vector data and must return a single // ExampleQueryResult result containing a VectorData entry matching each // OutputVectorSpec.vector_name. // // If the query instead returns no result, then it will be treated as is if // an error was returned. In that case, or if the query explicitly returns // an error, then the client will abort its session. // // The keys in the map are the names the vectors should be aggregated under, // and must match the keys in FederatedExampleQueryIORouter.aggregations. map output_vector_specs = 2; } // The queries to run. repeated ExampleQuery example_queries = 1; } // The input and output router for Federated Compute plans. // // This proto is the glue between the federated protocol and the TensorFlow // execution engine. This message describes how to prepare data coming from the // incoming `CheckinResponse` (defined in // fcp/protos/federated_api.proto) for the `TensorflowSpec`, and what // to do with outputs from `TensorflowSpec` (e.g. how to aggregate them back on // the server). // // TODO(team) we could replace `input_checkpoint_file_tensor_name` with // an `input_tensors` field, which would then be a tensor that contains the // input TensorProtos directly and skipping disk I/O, rather than referring to a // checkpoint file path. message FederatedComputeIORouter { // =========================================================================== // Inputs // =========================================================================== // The name of the scalar string tensor that is fed the file path to the // initial checkpoint (e.g. as provided via AcceptanceInfo.init_checkpoint). 
// // The federated protocol code would copy the `CheckinResponse`'s initial // checkpoint to a temporary file and then pass that file path through this // tensor. // // Ops may be added to the client graph that take this tensor as input and // reads the path. // // This field is optional. It may be omitted if the client graph does not use // an initial checkpoint. string input_filepath_tensor_name = 1; // The name of the scalar string tensor that is fed the file path to which // client work should serialize the bytes to send back to the server. // // The federated protocol code generates a temporary file and passes the file // path through this tensor. // // Ops may be be added to the client graph that use this tensor as an argument // to write files (e.g. writing checkpoints to disk). // // This field is optional. It must be omitted if the client graph does not // generate any output files (e.g. when all output tensors of `TensorflowSpec` // use Secure Aggregation). If this field is not set, then the `ReportRequest` // message in the federated protocol will not have the // `Report.update_checkpoint` field set. This absence of a value here can be // used to validate that the plan only uses Secure Aggregation. // // Conversely, if this field is set and executing the associated // TensorflowSpec does not write to the path is indication of an internal // framework error. The runtime should notify the caller that the computation // was setup incorrectly. string output_filepath_tensor_name = 2; // =========================================================================== // Outputs // =========================================================================== // Describes which output tensors should be aggregated using an aggregation // protocol, and the configuration for those protocols. // // Assertions: // - All keys must exist in the associated `TensorflowSpec` as // `output_tensor_specs.name` values. 
map aggregations = 3; } // The input and output router for client plans that do not use TensorFlow. // // This proto is the glue between the federated protocol and the example query // execution engine, describing how the query results should ultimately be // aggregated. message FederatedExampleQueryIORouter { // Describes how each output vector should be aggregated using an aggregation // protocol, and the configuration for those protocols. // Keys must match the keys in ExampleQuerySpec.output_vector_specs. // Note that currently only the TFV1CheckpointAggregation config is supported. map aggregations = 1; } // The specification for how to aggregate the associated tensor across clients // on the server. message AggregationConfig { oneof protocol_config { // Indicates that the given output tensor should be processed using Secure // Aggregation, using the specified config options. SecureAggregationConfig secure_aggregation = 2; // Note: in the future we could add a `SimpleAggregationConfig` to add // support for simple aggregation without writing to an intermediate // checkpoint file first. // Indicates that the given output tensor or vector (e.g. as produced by an // ExampleQuerySpec) should be placed in an output TF v1 checkpoint. // // Currently only ExampleQuerySpec output vectors are supported by this // aggregation type (i.e. it cannot be used with TensorflowSpec output // tensors). The vectors will be stored in the checkpoint as a 1-D Tensor of // its corresponding data type. TFV1CheckpointAggregation tf_v1_checkpoint_aggregation = 3; } } // Parameters for the SecAgg protocol (go/secagg). // // Currently only the server uses the SecAgg parameters, so we only use this // message to signify usage of SecAgg. message SecureAggregationConfig {} // Parameters for the TFV1 Checkpoint Aggregation protocol. // // Currently only ExampleQuerySpec output vectors are supported by this // aggregation type (i.e. it cannot be used with TensorflowSpec output // tensors). 
The vectors will be stored in the checkpoint as a 1-D Tensor of // its corresponding data type. message TFV1CheckpointAggregation {} // The input and output router for eligibility-computing plans. These plans // compute which other plans a client is eligible to run, and are returned by // clients via a `EligibilityEvalCheckinResponse` (defined in // fcp/protos/federated_api.proto). message FederatedComputeEligibilityIORouter { // The name of the scalar string tensor that is fed the file path to the // initial checkpoint (e.g. as provided via // `EligibilityEvalPayload.init_checkpoint`). // // For more detail see the // `FederatedComputeIoRouter.input_filepath_tensor_name`, which has the same // semantics. // // This field is optional. It may be omitted if the client graph does not use // an initial checkpoint. // // This tensor name must exist in the associated // `TensorflowSpec.input_tensor_specs` list. string input_filepath_tensor_name = 1; // Name of the output tensor (a string scalar) containing the serialized // `google.internal.federatedml.v2.TaskEligibilityInfo` proto output. The // client code will parse this proto and place it in the // `task_eligibility_info` field of the subsequent `CheckinRequest`. // // This tensor name must exist in the associated // `TensorflowSpec.output_tensor_specs` list. string task_eligibility_info_tensor_name = 2; } // The input and output router for Local Compute plans. // // This proto is the glue between the customers app and the TensorFlow // execution engine. This message describes how to prepare data coming from the // customer app (e.g. the input directory the app setup), and the temporary, // scratch output directory that will be notified to the customer app upon // completion of `TensorflowSpec`. 
message LocalComputeIORouter { // =========================================================================== // Inputs // =========================================================================== // The name of the placeholder tensor representing the input resource path(s). // It can be a single input directory or file path (in this case the // `input_dir_tensor_name` is populated) or multiple input resources // represented as a map from names to input directories or file paths (in this // case the `multiple_input_resources` is populated). // // In the multiple input resources case, the placeholder tensors are // represented as a map: the keys are the input resource names defined by the // users when constructing the `LocalComputation` Python object, and the // values are the corresponding placeholder tensor names created by the local // computation plan builder. // // Apps will have the ability to create contracts between their Android code // and `LocalComputation` toolkit code to place files inside the input // resource paths with known names (Android code) and create graphs with ops // to read from these paths (file names can be specified in toolkit code). oneof input_resource { string input_dir_tensor_name = 1; // Directly using the `map` field is not allowed in `oneof`, so we have to // wrap it in a new message. MultipleInputResources multiple_input_resources = 3; } // Scalar string tensor name that will contain the output directory path. // // The provided directory should be considered temporary scratch that will be // deleted, not persisted. It is the responsibility of the calling app to // move the desired files to a permanent location once the client returns this // directory back to the calling app. 
string output_dir_tensor_name = 2; // =========================================================================== // Outputs // =========================================================================== // NOTE: LocalCompute has no outputs other than what the client graph writes // to `output_dir` specified above. } // Describes the multiple input resources in `LocalComputeIORouter`. message MultipleInputResources { // The keys are the input resource names (defined by the users when // constructing the `LocalComputation` Python object), and the values are the // corresponding placeholder tensor names created by the local computation // plan builder. map input_resource_tensor_name_map = 1; } // Describes a queue to which input is fed. message AsyncInputFeed { // The op for enqueuing an example input. string enqueue_op = 1; // The input placeholders for the enqueue op. repeated string enqueue_params = 2; // The op for closing the input queue. string close_op = 3; // Whether the work that should be fed asynchronously is the data itself // or a description of where that data lives. bool feed_values_are_data = 4; } message DatasetInput { // Initializer of iterator corresponding to tf.data.Dataset object which // handles the input data. Stores name of an op in the graph. string initializer = 1; // Placeholders necessary to initialize the dataset. DatasetInputPlaceholders placeholders = 2; // Batch size to be used in tf.data.Dataset. int32 batch_size = 3; } message DatasetInputPlaceholders { // Name of placeholder corresponding to filename(s) of SSTable(s) to read data // from. string filename = 1; // Name of placeholder corresponding to key_prefix initializing the // SSTableDataset. Note the value fed should be unique user id, not a prefix. string key_prefix = 2; // Name of placeholder corresponding to number of rounds the local training // should be run for. string num_epochs = 3; // Name of placeholder corresponding to batch size. 
string batch_size = 4; } // Specifies an example selection procedure. message ExampleSelector { // Selection criteria following a contract agreed upon between client and // model designers. google.protobuf.Any criteria = 1; // A URI identifying the example collection to read from. Format should adhere // to "${COLLECTION}://${APP_NAME}${COLLECTION_NAME}". The URI segments // should adhere to the following rules: // - The scheme ${COLLECTION} should be one of: // - "app" for app-hosted example // - "simulation" for collections not connected to an app (e.g., if used // purely for simulation) // - The authority ${APP_NAME} identifies the owner of the example // collection and should be either the app's package name, or be left empty // (which means "the current app package name"). // - The path ${COLLECTION_NAME} can be any valid URI path. NB It starts with // a forward slash ("/"). // - The query and fragment are currently not used, but they may become used // for something in the future. To keep open that possibility they must // currently be left empty. // // Example: "app://com.google.some.app/someCollection/name" // identifies the collection "/someCollection/name" owned and hosted by the // app with package name "com.google.some.app". // // Example: "app:/someCollection/name" or "app:///someCollection/name" // both identify the collection "/someCollection/name" owned and hosted by the // app associated with the training job in which this URI appears. // // The path will not be interpreted by the runtime, and will be passed to the // example collection implementation for interpretation. Thus, in the case of // app-hosted example stores, the path segment's interpretation is a contract // between the app's example store developers, and the app's model designers. // // If an `app://` URI is set, then the `TrainerOptions` collection name must // not be set. 
string collection_uri = 2; // Resumption token following a contract agreed upon between client and // model designers. google.protobuf.Any resumption_token = 3; } // Selector for slices to fetch as part of a `federated_select` operation. message SlicesSelector { // The string ID under which the slices are served. // // This value must have been returned by a previous call to the `serve_slices` // op run during the `write_client_init` operation. string served_at_id = 1; // The indices of slices to fetch. repeated int32 keys = 2; } // Represents slice data to be served as part of a `federated_select` operation. // This is used for testing. message SlicesTestDataset { // The test data to use. The keys map to the `SlicesSelector.served_at_id` // field. E.g. test slice data for a slice with `served_at_id`="foo" and // `keys`=2 would be store in `dataset["foo"].slice_data[2]`. map dataset = 1; } message SlicesTestData { // The test slice data to serve. Each entry's index corresponds to the slice // key it is the test data for. repeated bytes slice_data = 2; } // Server Phase V2 // =============== // Represents a server phase with three distinct components: pre-broadcast, // aggregation, and post-aggregation. // // The pre-broadcast and post-aggregation components are described with // the tensorflow_spec_prepare and tensorflow_spec_result TensorflowSpec // messages, respectively. These messages in combination with the server // IORouter messages specify how to set up a single TF sess.run call for each // component. // // The pre-broadcast logic is obtained by transforming the server_prepare TFF // computation in the DistributeAggregateForm. It takes the server state as // input, and it generates the checkpoint to broadcast to the clients and // potentially an intermediate server state. The intermediate server state may // be used by the aggregation and post-aggregation logic. 
//
// The aggregation logic represents the aggregation of client results at the
// server and is described using a list of ServerAggregationConfig messages.
// Each ServerAggregationConfig message describes a single aggregation operation
// on a set of input/output tensors. The input tensors may represent parts of
// either the client results or the intermediate server state. These messages
// are obtained by transforming the client_to_server_aggregation TFF computation
// in the DistributeAggregateForm.
//
// The post-aggregation logic is obtained by transforming the server_result TFF
// computation in the DistributeAggregateForm. It takes the intermediate server
// state and the aggregated client results as input, and it generates the new
// server state and potentially other server-side output.
//
// Note that while a ServerPhaseV2 message can be generated for all types of
// intrinsics, it is currently only compatible with the ClientPhase message if
// the aggregations being used are exclusively federated_sum (not SecAgg). If
// this compatibility requirement is satisfied, it is also valid to run the
// aggregation portion of this ServerPhaseV2 message alongside the pre- and
// post-aggregation logic from the original ServerPhase message. Ultimately,
// we expect the full ServerPhaseV2 message to be run and the ServerPhase
// message to be deprecated.
message ServerPhaseV2 {
  // Short CamelCase name of this ServerPhaseV2.
  string name = 1;

  // Functional interface for the TensorFlow logic the server runs before the
  // server-to-client broadcast. To be used together with the TensorFlow graph
  // defined in server_graph_prepare_bytes.
  TensorflowSpec tensorflow_spec_prepare = 3;

  // The specification of inputs needed by the server_prepare TF logic.
  oneof server_prepare_io_router {
    ServerPrepareIORouter prepare_router = 4;
  }

  // The client-to-server aggregations to perform.
  repeated ServerAggregationConfig aggregations = 2;

  // Functional interface for the TensorFlow logic the server runs
  // post-aggregation. To be used together with the TensorFlow graph defined in
  // server_graph_result_bytes.
  TensorflowSpec tensorflow_spec_result = 5;

  // The specification of inputs and outputs needed by the server_result TF
  // logic.
  oneof server_result_io_router {
    ServerResultIORouter result_router = 6;
  }
}

// Routing for the server_prepare graph.
message ServerPrepareIORouter {
  // Name of the scalar string tensor in the server_prepare TF graph that is
  // fed the filepath to the initial server state checkpoint. The
  // server_prepare logic reads from this filepath.
  string prepare_server_state_input_filepath_tensor_name = 1;

  // Name of the scalar string tensor in the server_prepare TF graph that is
  // fed the filepath where the client checkpoint should be stored. The
  // server_prepare logic writes to this filepath.
  string prepare_output_filepath_tensor_name = 2;

  // Name of the scalar string tensor in the server_prepare TF graph that is
  // fed the filepath where the intermediate state checkpoint should be
  // stored. The server_prepare logic writes to this filepath. The intermediate
  // state checkpoint will be consumed by both the logic used to set parameters
  // for aggregation and the post-aggregation logic.
  string prepare_intermediate_state_output_filepath_tensor_name = 3;
}

// Routing for the server_result graph.
message ServerResultIORouter {
  // Name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath to the intermediate state checkpoint. The server_result
  // logic reads from this filepath.
  string result_intermediate_state_input_filepath_tensor_name = 1;

  // Name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath to the aggregated client result checkpoint. The
  // server_result logic reads from this filepath.
  string result_aggregate_result_input_filepath_tensor_name = 2;

  // Name of the scalar string tensor in the server_result TF graph that is
  // fed the filepath where the updated server state should be stored. The
  // server_result logic writes to this filepath.
  string result_server_state_output_filepath_tensor_name = 3;
}

// Represents a single aggregation operation, combining one or more input
// tensors from a collection of clients into one or more output tensors on the
// server.
message ServerAggregationConfig {
  // The uri of the aggregation intrinsic (e.g. 'federated_sum').
  string intrinsic_uri = 1;

  // Describes an argument to the aggregation operation.
  message IntrinsicArg {
    oneof arg {
      // Refers to a tensor within the checkpoint provided by each client.
      tensorflow.TensorSpecProto input_tensor = 2;

      // Refers to a tensor within the intermediate server state checkpoint.
      tensorflow.TensorSpecProto state_tensor = 3;
    }
  }

  // List of arguments for the aggregation operation. The arguments can be
  // dependent on client data (in which case they must be retrieved from
  // clients) or they can be independent of client data (in which case they
  // can be configured server-side). For now we assume all client-independent
  // arguments are constants. The arguments must be in the order expected by
  // the server.
  repeated IntrinsicArg intrinsic_args = 4;

  // List of server-side outputs produced by the aggregation operation.
  repeated tensorflow.TensorSpecProto output_tensors = 5;

  // List of inner aggregation intrinsics. This can be used to delegate parts
  // of the aggregation logic (e.g. a groupby intrinsic may want to delegate
  // a sum operation to a sum intrinsic).
  repeated ServerAggregationConfig inner_aggregations = 6;
}

// Server Phase
// ============

// Represents a server phase which implements TF-based aggregation of multiple
// client updates.
//
// There are two different modes of aggregation that are described
// by the values in this message.
// The first is aggregation that is
// coming from coordinated sets of clients. This includes aggregation
// done via checkpoints from clients or aggregation done over a set
// of clients by a process like secure aggregation. The results of
// this first aggregation are saved to intermediate aggregation
// checkpoints. The second aggregation then comes from taking
// these intermediate checkpoints and aggregating over them.
//
// These two different modes of aggregation are done on different
// servers, the first in the 'L1' servers and the second in the
// 'L2' servers, so we use this nomenclature to describe these
// phases below.
//
// The ServerPhase message is currently in the process of being replaced by the
// ServerPhaseV2 message as we switch the plan building pipeline to use
// DistributeAggregateForm instead of MapReduceForm. During the migration
// process, we may generate both messages and use components from either
// message during execution.
//
message ServerPhase {
  // Short CamelCase name of this ServerPhase.
  string name = 8;

  // ===========================================================================
  // L1 "Intermediate" Aggregation.
  //
  // This is the initial aggregation that creates partial aggregates from client
  // results. L1 Aggregation may be run on many different instances.
  //
  // Pre-condition:
  //   The execution environment has loaded the graph from `server_graph_bytes`.

  // 1. Initialize the phase.
  //
  // Operation to run before the first aggregation happens.
  // For instance, clears the accumulators so that a new aggregation can begin.
  string phase_init_op = 1;

  // 2. For each client in set of clients:
  //   a. Restore variables from the client checkpoint.
  //
  // Loads a checkpoint from a single client written via
  // `FederatedComputeIORouter.output_filepath_tensor_name`. This is done once
  // for every client checkpoint in a round.
  CheckpointOp read_update = 3;

  //   b. Aggregate the data coming from the client checkpoint.
  //
  // An operation that aggregates the data from read_update.
  // Generally this will add to accumulators and it may leverage internal data
  // inside the graph to adjust the weights of the Tensors.
  //
  // Executed once for each `read_update`, to (for example) update accumulator
  // variables using the values loaded during `read_update`.
  string aggregate_into_accumulators_op = 4;

  // 3. After all clients have been aggregated, possibly restore
  //    variables that have been aggregated via a separate process.
  //
  // Optionally restores variables where aggregation is done across
  // an entire round of client data updates. In contrast to `read_update`,
  // which restores once per client, this occurs after all clients
  // in a round have been processed. This allows, for example, side
  // channels where aggregation is done by a separate process (such
  // as in secure aggregation), in which the side channel aggregated
  // tensor is passed to the `before_restore_op` which ensures the
  // variables are restored properly. The `after_restore_op` will then
  // be responsible for performing the accumulation.
  //
  // Note that in current use this should not have a SaverDef, but
  // should only be used for side channels.
  CheckpointOp read_aggregated_update = 10;

  // 4. Write the aggregated variables to an intermediate checkpoint.
  //
  // We require that `aggregate_into_accumulators_op` is associative and
  // commutative, so that the aggregates can be computed across
  // multiple TensorFlow sessions.
  // As an example, say we are computing the sum of 5 client updates:
  //    A = X1 + X2 + X3 + X4 + X5
  // We can always do this in one session by calling `read_update` and
  // `aggregate_into_accumulators_op` once for each client checkpoint.
  //
  // Alternatively, we could compute:
  //    A1 = X1 + X2 in one TensorFlow session, and
  //    A2 = X3 + X4 + X5 in a different session.
  // Each of these sessions can then write their accumulator state
  // with the `write_intermediate_update` CheckpointOp, and yet another
  // (third) session can then call `read_intermediate_update` and
  // `aggregate_into_accumulators_op` on each of these checkpoints to compute:
  //    A = A1 + A2 = (X1 + X2) + (X3 + X4 + X5).
  CheckpointOp write_intermediate_update = 7;

  // End L1 "Intermediate" Aggregation.
  // ===========================================================================

  // ===========================================================================
  // L2 Aggregation and Coordinator.
  //
  // This aggregates intermediate checkpoints from L1 Aggregation and performs
  // the finalizing of the update. Unlike L1 there will only be one instance
  // that does this aggregation.

  // Pre-condition:
  //   The execution environment has loaded the graph from `server_graph_bytes`
  //   and restored the global model using `server_savepoint` from the parent
  //   `Plan` message.

  // 1. Initialize the phase.
  //
  // This currently re-uses the `phase_init_op` from L1 aggregation above.

  // 2. Write a checkpoint that can be sent to the client.
  //
  // Generates a checkpoint to be sent to the client, to be read by
  // `FederatedComputeIORouter.input_filepath_tensor_name`.
  CheckpointOp write_client_init = 2;

  // 3. For each intermediate checkpoint:
  //   a. Restore variables from the intermediate checkpoint.
  //
  // The corresponding read checkpoint op to the write_intermediate_update.
  // This is used instead of read_update for intermediate checkpoints because
  // the format of these updates may be different than those used in updates
  // from clients (which may, for example, be compressed).
  CheckpointOp read_intermediate_update = 9;

  //   b. Aggregate the data coming from the intermediate checkpoint.
  //
  // An operation that aggregates the data from `read_intermediate_update`.
  // Generally this will add to accumulators and it may leverage internal data
  // inside the graph to adjust the weights of the Tensors.
  string intermediate_aggregate_into_accumulators_op = 11;

  // 4. Write the aggregated intermediate variables to a checkpoint.
  //
  // This is used for downstream, cross-round aggregation of metrics.
  // These variables will be read back into a session with
  // read_intermediate_update.
  //
  // Tasks which do not use FL metrics may unset the CheckpointOp.saver_def
  // to disable writing accumulator checkpoints.
  CheckpointOp write_accumulators = 12;

  // 5. Finalize the round.
  //
  // This can include:
  // - Applying the update aggregated from the intermediate checkpoints to the
  //   global model and other updates to cross-round state variables.
  // - Computing final round metric values (e.g. the `report` of a
  //   `tff.federated_aggregate`).
  //
  // NOTE: the field name is misspelled ("aggregrated"). It is intentionally
  // left unchanged, since renaming it would break generated code and
  // text-format/JSON users of this message.
  string apply_aggregrated_updates_op = 5;

  // 6. Fetch the server aggregated metrics.
  //
  // A list of names of metric variables to fetch from the TensorFlow session.
  repeated Metric metrics = 6;

  // 7. Serialize the updated server state (e.g. the coefficients of the global
  //    model in FL) using `server_savepoint` in the parent `Plan` message.

  // End L2 Aggregation.
  // ===========================================================================
}

// Represents the server phase in an eligibility computation.
//
// This phase produces a checkpoint to be sent to clients. This checkpoint is
// then used as an input to the clients' task eligibility computations.
// This phase *does not include any aggregation.*
message ServerEligibilityComputationPhase {
  // Short CamelCase name of this ServerEligibilityComputationPhase.
  string name = 1;

  // The names of the TensorFlow nodes to run in order to produce output.
  repeated string target_node_names = 2;

  // The specification of inputs and outputs to the TensorFlow graph.
  oneof server_eligibility_io_router {
    TEContextServerEligibilityIORouter task_eligibility = 3 [lazy = true];
  }
}

// Represents the inputs and outputs of a `ServerEligibilityComputationPhase`
// which takes a single `TaskEligibilityContext` as input.
message TEContextServerEligibilityIORouter {
  // The name of the scalar string tensor that must be fed a serialized
  // `TaskEligibilityContext`.
  string context_proto_input_tensor_name = 1;

  // The name of the scalar string tensor that must be fed the path to which
  // the server graph should write the checkpoint file to be sent to the client.
  string output_filepath_tensor_name = 2;
}

// Plan
// =====

// Represents the overall plan for performing federated optimization or
// personalization, as handed over to the production system. This will
// typically be split down into individual pieces for different production
// parts, e.g. server and client side.
// NEXT_TAG: 15
message Plan {
  reserved 1, 3, 5;

  // The actual type of the server_*_graph_bytes fields below is expected to be
  // tensorflow.GraphDef. The TensorFlow graphs are stored in serialized form
  // for two reasons.
  // 1) We may use execution engines other than TensorFlow.
  // 2) We wish to avoid the cost of deserializing and re-serializing large
  //    graphs, in the Federated Learning service.

  // While we migrate from ServerPhase to ServerPhaseV2, server_graph_bytes,
  // server_graph_prepare_bytes, and server_graph_result_bytes may all be set.
  // If we're using a MapReduceForm-based server implementation, only
  // server_graph_bytes will be used. If we're using a DistributeAggregateForm-
  // based server implementation, only server_graph_prepare_bytes and
  // server_graph_result_bytes will be used.

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhase. For personalization, this will not be set.
  google.protobuf.Any server_graph_bytes = 7;

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhaseV2.tensorflow_spec_prepare.
  google.protobuf.Any server_graph_prepare_bytes = 13;

  // Optional. The TensorFlow graph used for all server processing described by
  // ServerPhaseV2.tensorflow_spec_result.
  google.protobuf.Any server_graph_result_bytes = 14;

  // A savepoint to sync the server checkpoint with a persistent
  // storage system. The storage initially holds a seeded checkpoint
  // which can subsequently be read and updated by this savepoint.
  // Optional-- not present in eligibility computation plans (those with a
  // ServerEligibilityComputationPhase). This is used in conjunction with
  // ServerPhase only.
  CheckpointOp server_savepoint = 2;

  // Required. The TensorFlow graph that describes the TensorFlow logic a client
  // should perform. It should be consistent with the `TensorflowSpec` field in
  // the `client_phase`. The actual type is expected to be tensorflow.GraphDef.
  // The TensorFlow graph is stored in serialized form for two reasons.
  // 1) We may use execution engines other than TensorFlow.
  // 2) We wish to avoid the cost of deserializing and re-serializing large
  //    graphs, in the Federated Learning service.
  google.protobuf.Any client_graph_bytes = 8;

  // Optional. The FlatBuffer used for TFLite training.
  // It contains the same model information as the client_graph_bytes, but with
  // a different format.
  bytes client_tflite_graph_bytes = 12;

  // A pair of client phase and server phase which are processed in
  // sync. The server execution defines how the results of a client
  // phase are aggregated, and how the checkpoints for clients are
  // generated.
  message Phase {
    // Required. The client phase.
    ClientPhase client_phase = 1;

    // Optional. Server phase for TF-based aggregation; not provided for
    // personalization or eligibility tasks.
    ServerPhase server_phase = 2;

    // Optional. Server phase for native aggregation; only provided for tasks
    // that have enabled the corresponding flag.
    ServerPhaseV2 server_phase_v2 = 4;

    // Optional. Only provided for eligibility tasks.
    ServerEligibilityComputationPhase server_eligibility_phase = 3;
  }

  // A pair of client and server computations to run.
  repeated Phase phase = 4;

  // Metrics that are persistent across different phases. This
  // includes, for example, counters that track how much work of
  // different kinds has been done.
  repeated Metric metrics = 6;

  // Describes how metrics in both the client and server phases should be
  // aggregated.
  repeated OutputMetric output_metrics = 10;

  // Version of the plan:
  // version == 0 - Old plan without version field, containing b/65131070
  // version >= 1 - plan supports multi-shard aggregation mode (L1/L2)
  int32 version = 9;

  // A TensorFlow ConfigProto packed in an Any.
  //
  // If this field is unset, if the Any proto is set but empty, or if the Any
  // proto is populated with an empty ConfigProto (i.e. its `type_url` field is
  // set, but the `value` field is empty) then the client implementation may
  // choose a set of configuration parameters to provide to TensorFlow by
  // default.
  //
  // In all other cases this field must contain a valid packed ConfigProto
  // (invalid values will result in an error at execution time), and in this
  // case the client will not provide any other configuration parameters by
  // default.
  google.protobuf.Any tensorflow_config_proto = 11;
}

// Represents a client part of the plan of federated optimization.
// This is also used to describe a client-only plan for standalone on-device
// training, known as personalization.
// NEXT_TAG: 6
message ClientOnlyPlan {
  reserved 3;

  // The graph to use for training, in binary form.
  bytes graph = 1;

  // Optional. The flatbuffer used for TFLite training.
  // Whether "graph" or "tflite_graph" is used for training is up to the client
  // code to allow for a flag-controlled a/b rollout.
  bytes tflite_graph = 5;

  // The client phase to execute.
  ClientPhase phase = 2;

  // A TensorFlow ConfigProto.
  google.protobuf.Any tensorflow_config_proto = 4;
}

// Represents the cross round aggregation portion for user defined measurements.
// This is used by tools that process / analyze accumulator checkpoints
// after a round of computation, to achieve aggregation beyond a round.
message CrossRoundAggregationExecution {
  // Operation to run before reading accumulator checkpoint.
  string init_op = 1;

  // Reads accumulator checkpoint.
  CheckpointOp read_aggregated_update = 2;

  // Operation to merge loaded checkpoint into accumulator.
  string merge_op = 3;

  // Reads and writes the final aggregated accumulator vars.
  CheckpointOp read_write_final_accumulators = 6;

  // Metadata for mapping the TensorFlow `name` attribute of the `tf.Variable`
  // to the user defined name of the signal.
  repeated Measurement measurements = 4;

  // The `tf.Graph` used for aggregating accumulator checkpoints when
  // loading metrics.
  google.protobuf.Any cross_round_aggregation_graph_bytes = 5;
}

message Measurement {
  // Name of a TensorFlow op to run to read/fetch the value of this measurement.
  string read_op_name = 1;

  // A human-readable name for the measurement. Names are usually
  // camel case by convention, e.g., 'Loss', 'AbsLoss', or 'Accuracy'.
  string name = 2;

  reserved 3;

  // A serialized `tff.Type` for the measurement.
  bytes tff_type = 4;
}