tff.learning.templates.compose_learning_process | TensorFlow Federated

Composes specialized measured processes into a learning process.

tff.learning.templates.compose_learning_process(
    initial_model_weights_fn: tff.Computation,
    model_weights_distributor: tff.learning.templates.DistributionProcess,
    client_work: tff.learning.templates.ClientWorkProcess,
    model_update_aggregator: tff.templates.AggregationProcess,
    model_finalizer: tff.learning.templates.FinalizerProcess
) -> tff.learning.templates.LearningProcess

Used in the tutorials
Composing Learning Algorithms

Given 4 specialized measured processes (described below) that together make up a learning process, and a computation that returns the initial model weights to be used for training, this method validates that the processes fit together and returns a LearningProcess. See the tutorial at https://www.tensorflow.org/federated/tutorials/composing_learning_algorithms for more details on composing learning processes.
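The idea of "validating that the processes fit together" can be sketched in plain Python. This is a simplified illustration, not TFF's implementation: `StageSignature` and `validate_chain` are hypothetical names, types are plain strings such as "weights@SERVER" instead of real tff.Type objects, and the chain is treated as strictly linear (in TFF the finalizer also receives the current global weights).

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class StageSignature:
    """Hypothetical stand-in for one process's `next` input/output types.

    TFF compares real tff.Type objects; plain strings are used here only
    to illustrate checking that adjacent stages line up.
    """
    name: str
    input_type: str
    output_type: str


def validate_chain(initial_weights_type: str,
                   stages: list[StageSignature]) -> None:
    """Raises TypeError if one stage's output cannot feed the next stage."""
    current = initial_weights_type
    for stage in stages:
        if stage.input_type != current:
            raise TypeError(
                f"{stage.name} expects {stage.input_type!r}, "
                f"but the previous stage produces {current!r}")
        current = stage.output_type
    # The last stage's output becomes the next round's global weights,
    # so it must match the type of the initial weights.
    if current != initial_weights_type:
        raise TypeError(
            f"final output {current!r} does not match initial weights "
            f"{initial_weights_type!r}")


stages = [
    StageSignature("model_weights_distributor", "weights@SERVER", "weights@CLIENTS"),
    StageSignature("client_work", "weights@CLIENTS", "update@CLIENTS"),
    StageSignature("model_update_aggregator", "update@CLIENTS", "update@SERVER"),
    StageSignature("model_finalizer", "update@SERVER", "weights@SERVER"),
]
validate_chain("weights@SERVER", stages)  # compatible chain: no error raised
```

Swapping in an aggregator that expects, say, "grads@CLIENTS" would make `validate_chain` raise a TypeError naming the mismatched stage.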

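The main purposes of the 4 measured processes are:

  model_weights_distributor: makes the global model weights at the server available to clients as the starting point for their learning work.
  client_work: produces an update to the model received at the clients.
  model_update_aggregator: aggregates the model updates from the clients to the server.
  model_finalizer: updates the global model weights at the server using the aggregated model update.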

The next computation of the created learning process is composed from the next computations of the 4 measured processes, chained in the order visualized below. The type signatures of the processes must be such that this chaining is possible. Each process also reports its own metrics.

┌─────────────────────────┐
│model_weights_distributor│
└△─┬─┬────────────────────┘
 │ │┌▽──────────┐
 │ ││client_work│
 │ │└┬─────┬────┘
 │┌▽─▽────┐│
 ││metrics││
 │└△─△────┘│
 │ │┌┴─────▽────────────────┐
 │ ││model_update_aggregator│
 │ │└┬──────────────────────┘
┌┴─┴─▽──────────┐
│model_finalizer│
└┬──────────────┘
┌▽─────┐
│result│
└──────┘
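The chaining in the diagram above can be sketched with plain Python callables. This is a minimal illustration under stated assumptions, not TFF code: `SimpleProcess`, `MeasuredProcessOutput` and `compose_sketch` are hypothetical names, "weights" are scalars, and the state keys (global_model_weights, distributor, client_work, aggregator, finalizer) are chosen for illustration. It shows how one round threads the global weights through the four stages while each stage keeps its own state and reports its own metrics.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class MeasuredProcessOutput:
    """Illustrative (state, result, measurements) triple."""
    state: Any
    result: Any
    measurements: dict


@dataclass
class SimpleProcess:
    """Hypothetical stand-in for a measured process: initialize() and next()."""
    initialize: Callable[[], Any]
    next: Callable[..., MeasuredProcessOutput]


@dataclass
class SimpleLearningProcess:
    initialize: Callable[[], dict]
    next: Callable[[dict, Any], tuple]


def compose_sketch(initial_weights_fn, distributor, client_work,
                   aggregator, finalizer) -> SimpleLearningProcess:
    def initialize() -> dict:
        # Composite state: global weights plus each sub-process's own state.
        return {
            "global_model_weights": initial_weights_fn(),
            "distributor": distributor.initialize(),
            "client_work": client_work.initialize(),
            "aggregator": aggregator.initialize(),
            "finalizer": finalizer.initialize(),
        }

    def next_fn(state: dict, client_data: Any) -> tuple:
        weights = state["global_model_weights"]
        # 1. Make the server weights available to clients.
        dist = distributor.next(state["distributor"], weights)
        # 2. Compute per-client model updates from the distributed weights.
        work = client_work.next(state["client_work"], dist.result, client_data)
        # 3. Aggregate the client updates into one server-side update.
        agg = aggregator.next(state["aggregator"], work.result)
        # 4. Apply the aggregated update to the current global weights.
        fin = finalizer.next(state["finalizer"], weights, agg.result)
        new_state = {
            "global_model_weights": fin.result,
            "distributor": dist.state,
            "client_work": work.state,
            "aggregator": agg.state,
            "finalizer": fin.state,
        }
        # Each stage reports its own measurements under its own key.
        metrics = {
            "distributor": dist.measurements,
            "client_work": work.measurements,
            "aggregator": agg.measurements,
            "finalizer": fin.measurements,
        }
        return new_state, metrics

    return SimpleLearningProcess(initialize=initialize, next=next_fn)


# Toy processes on scalar weights (all hypothetical).
distributor = SimpleProcess(
    initialize=lambda: (),
    next=lambda s, w: MeasuredProcessOutput(s, w, {}))
client_work = SimpleProcess(
    initialize=lambda: (),
    next=lambda s, w, data: MeasuredProcessOutput(
        s, [sum(d) / len(d) - w for d in data], {"num_clients": len(data)}))
aggregator = SimpleProcess(
    initialize=lambda: (),
    next=lambda s, updates: MeasuredProcessOutput(
        s, sum(updates) / len(updates), {}))
finalizer = SimpleProcess(
    initialize=lambda: {"learning_rate": 0.5},
    next=lambda s, w, u: MeasuredProcessOutput(
        s, w + s["learning_rate"] * u, {}))

process = compose_sketch(lambda: 0.0, distributor, client_work,
                         aggregator, finalizer)
state, metrics = process.next(process.initialize(), [[1.0, 3.0], [4.0]])
```

Here the two clients produce updates 2.0 and 4.0, the aggregator averages them to 3.0, and the finalizer applies a learning rate of 0.5, so the new global weight is 1.5.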

The get_hparams computation of the created learning process produces a nested ordered dictionary containing the results of client_work.get_hparams and model_finalizer.get_hparams. The set_hparams computation works similarly, delegating to client_work.set_hparams and model_finalizer.set_hparams to set the hyperparameters in their associated states.
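The delegation pattern for hyperparameters can be sketched as follows. This is a plain-Python illustration, not the TFF implementation: `compose_hparams_fns` is a hypothetical helper, the composite state is a dict whose sub-states live under "client_work" and "finalizer", and the nested keys used for the ordered dictionary are an assumption for illustration.

```python
import collections
from typing import Any, Callable, Dict


def compose_hparams_fns(client_work_get: Callable[[Any], Dict],
                        client_work_set: Callable[[Any, Dict], Any],
                        finalizer_get: Callable[[Any], Dict],
                        finalizer_set: Callable[[Any, Dict], Any]):
    """Builds illustrative get_hparams/set_hparams over a composite state."""

    def get_hparams(state: Dict) -> collections.OrderedDict:
        # Nest each sub-process's hyperparameters under its own key.
        return collections.OrderedDict(
            client_work=client_work_get(state["client_work"]),
            finalizer=finalizer_get(state["finalizer"]),
        )

    def set_hparams(state: Dict, hparams: Dict) -> Dict:
        # Delegate to each sub-process; leave the rest of the state untouched.
        new_state = dict(state)
        new_state["client_work"] = client_work_set(
            state["client_work"], hparams["client_work"])
        new_state["finalizer"] = finalizer_set(
            state["finalizer"], hparams["finalizer"])
        return new_state

    return get_hparams, set_hparams


# Toy accessors: each sub-state stores its learning rate directly.
def get_lr(sub_state: Dict) -> Dict:
    return {"learning_rate": sub_state["learning_rate"]}


def set_lr(sub_state: Dict, hparams: Dict) -> Dict:
    return {**sub_state, "learning_rate": hparams["learning_rate"]}


get_hparams, set_hparams = compose_hparams_fns(get_lr, set_lr, get_lr, set_lr)
state = {"global_model_weights": 0.0,
         "client_work": {"learning_rate": 0.1},
         "finalizer": {"learning_rate": 1.0}}
hparams = get_hparams(state)
hparams["finalizer"]["learning_rate"] = 0.5
state = set_hparams(state, hparams)
```

Because the composed functions only route each nested entry to the matching sub-process, a sub-process can change how it stores its hyperparameters without affecting the composite.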

Args
  initial_model_weights_fn: A tff.Computation that returns (unplaced) initial model weights.
  model_weights_distributor: A tff.learning.templates.DistributionProcess.
  client_work: A tff.learning.templates.ClientWorkProcess.
  model_update_aggregator: A tff.templates.AggregationProcess.
  model_finalizer: A tff.learning.templates.FinalizerProcess.

Returns
  A tff.learning.templates.LearningProcess.

Raises
  ClientSequenceTypeError: If the first argument of the next method of the resulting LearningProcess is not a structure of sequences placed at tff.CLIENTS.
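The condition behind ClientSequenceTypeError can be illustrated with a small, self-contained type check. Everything here is hypothetical: `ToySequenceType` and `ToyPlacedType` stand in for TFF's own type objects, placements are plain strings, and `ClientSequenceTypeError` is a local class rather than the one TFF raises. The sketch only shows what "a structure of sequences placed at tff.CLIENTS" means.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Any


class ClientSequenceTypeError(Exception):
    """Local stand-in for the error named above (not imported from TFF)."""


@dataclass(frozen=True)
class ToySequenceType:
    """Hypothetical stand-in for a sequence (dataset) type."""
    element: str


@dataclass(frozen=True)
class ToyPlacedType:
    """Hypothetical stand-in for a value type with a placement."""
    member: Any
    placement: str


def is_structure_of_sequences(t: Any) -> bool:
    """True if t is a sequence type or a nested dict/list/tuple of them."""
    if isinstance(t, ToySequenceType):
        return True
    if isinstance(t, dict):
        return all(is_structure_of_sequences(v) for v in t.values())
    if isinstance(t, (list, tuple)):
        return all(is_structure_of_sequences(v) for v in t)
    return False


def check_client_data_type(t: Any) -> None:
    """Raises unless t is a structure of sequences placed at CLIENTS."""
    if not (isinstance(t, ToyPlacedType)
            and t.placement == "CLIENTS"
            and is_structure_of_sequences(t.member)):
        raise ClientSequenceTypeError(
            f"expected a structure of sequences placed at CLIENTS, got {t!r}")


# Accepted: a dict of per-client datasets placed at CLIENTS.
check_client_data_type(ToyPlacedType(
    {"x": ToySequenceType("float32"), "y": ToySequenceType("int32")}, "CLIENTS"))
```

A non-sequence member (for example a single float tensor per client) or a SERVER placement would raise the local ClientSequenceTypeError.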