tff.learning.templates.compose_learning_process | TensorFlow Federated
Composes specialized measured processes into a learning process.
tff.learning.templates.compose_learning_process(
initial_model_weights_fn: tff.Computation,
model_weights_distributor: tff.learning.templates.DistributionProcess,
client_work: tff.learning.templates.ClientWorkProcess,
model_update_aggregator: tff.templates.AggregationProcess,
model_finalizer: tff.learning.templates.FinalizerProcess
) -> tff.learning.templates.LearningProcess
Used in the tutorials: Composing Learning Algorithms
Given 4 specialized measured processes (described below) that together make up a learning process, and a computation that returns the initial model weights to be used for training, this method validates that the processes fit together and returns a LearningProcess. See the tutorial at https://www.tensorflow.org/federated/tutorials/composing_learning_algorithms for more details on composing learning processes.
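The validation step can be illustrated with a toy, plain-Python sketch. It does not use TFF. ToyProcess and validate_chain are hypothetical names. Each process is reduced to a single input "type" string and a single output "type" string, which is a simplification: the real processes also carry state, and client_work and model_finalizer take additional arguments. The check only asserts that each process's output type matches the next process's input type.

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass(frozen=True)
class ToyProcess:
    """Hypothetical stand-in for a measured process; only its I/O types matter here."""
    name: str
    input_type: str
    output_type: str


def validate_chain(processes: Sequence[ToyProcess]) -> None:
    """Raise TypeError if a process's output type differs from the next one's input type."""
    for upstream, downstream in zip(processes, processes[1:]):
        if upstream.output_type != downstream.input_type:
            raise TypeError(
                f"{upstream.name} produces {upstream.output_type!r}, "
                f"but {downstream.name} expects {downstream.input_type!r}"
            )


# Simplified "types": a value kind plus its placement (hypothetical encoding).
chain = [
    ToyProcess("model_weights_distributor", "weights@SERVER", "weights@CLIENTS"),
    ToyProcess("client_work", "weights@CLIENTS", "update@CLIENTS"),
    ToyProcess("model_update_aggregator", "update@CLIENTS", "update@SERVER"),
    ToyProcess("model_finalizer", "update@SERVER", "weights@SERVER"),
]
validate_chain(chain)  # The four stages fit together, so no error is raised.
```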
The main purposes of the 4 measured processes are:
- model_weights_distributor: Makes the global model weights at the server available to clients as the starting point for their learning work.
- client_work: Produces an update to the model received at clients.
- model_update_aggregator: Aggregates the model updates from clients to the server.
- model_finalizer: Updates the global model weights at the server using the aggregated model update.
The next computation of the created learning process is composed from the next computations of the 4 measured processes, in the order visualized below. The type signatures of the processes must allow this chaining. Each process also reports its own metrics.
┌─────────────────────────┐
│model_weights_distributor│
└△─┬─┬────────────────────┘
 │ │┌▽──────────┐
 │ ││client_work│
 │ │└┬─────┬────┘
 │┌▽─▽────┐│
 ││metrics││
 │└△─△────┘│
 │ │┌┴─────▽────────────────┐
 │ ││model_update_aggregator│
 │ │└┬──────────────────────┘
┌┴─┴─▽──────────┐
│model_finalizer│
└┬──────────────┘
┌▽─────┐
│result│
└──────┘
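The chaining can be mimicked in plain Python to make the data flow concrete. This is a simplified sketch, not TFF code. The function names, the scalar "model", the toy update rule, and the metric names and keys are all hypothetical. The per-process state that real measured processes thread through next is also omitted. Each stage returns its output together with its own metrics, and next_round collects those metrics into one ordered dictionary keyed by stage.

```python
from collections import OrderedDict
from statistics import mean


def distributor(server_weights, num_clients):
    # Broadcast: every client starts from a copy of the global weights.
    return [server_weights] * num_clients, {"num_broadcast": num_clients}


def client_work(client_weights, client_data):
    # Toy local "training": each update moves a client's scalar weight to its data mean.
    updates = [mean(data) - w for w, data in zip(client_weights, client_data)]
    return updates, {"num_examples": sum(len(d) for d in client_data)}


def aggregator(updates):
    # Unweighted mean of the client updates, delivered to the server.
    return mean(updates), {}


def finalizer(server_weights, aggregated_update, learning_rate=1.0):
    # Apply the aggregated update to the global weights at the server.
    new_weights = server_weights + learning_rate * aggregated_update
    return new_weights, {"update_norm": abs(aggregated_update)}


def next_round(server_weights, client_data):
    """One round: distributor -> client_work -> aggregator -> finalizer."""
    client_weights, dist_metrics = distributor(server_weights, len(client_data))
    updates, work_metrics = client_work(client_weights, client_data)
    aggregated, agg_metrics = aggregator(updates)
    new_weights, fin_metrics = finalizer(server_weights, aggregated)
    metrics = OrderedDict(
        distributor=dist_metrics,
        client_work=work_metrics,
        aggregator=agg_metrics,
        finalizer=fin_metrics,
    )
    return new_weights, metrics


weights, metrics = next_round(0.0, [[1.0, 3.0], [4.0]])
```

With two clients whose data means are 2.0 and 4.0, the mean update is 3.0, so the new global weight is 3.0.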
The get_hparams computation of the created learning process produces a nested ordered dictionary containing the results of client_work.get_hparams and model_finalizer.get_hparams. The set_hparams computation works similarly: it delegates to client_work.set_hparams and model_finalizer.set_hparams to set the hyperparameters in their associated states.
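The delegation can be sketched in plain Python under some assumptions. The sketch assumes the learning process state is a dictionary with one entry per component, and that the hyperparameter dictionary uses the keys "client_work" and "finalizer". Both are assumptions for illustration. ToyProcess and this state layout are hypothetical, not TFF's actual types. set_hparams returns a new state instead of mutating the old one. Components whose hyperparameters are not exposed, such as the aggregator here, are passed through unchanged.

```python
from collections import OrderedDict


class ToyProcess:
    """Hypothetical component whose state is a flat dict of hyperparameters."""

    def get_hparams(self, state):
        # Return a copy so callers can edit it without touching the state.
        return OrderedDict(state)

    def set_hparams(self, state, hparams):
        # Build a new state; the old state is left unchanged.
        new_state = dict(state)
        new_state.update(hparams)
        return new_state


client_work = ToyProcess()
model_finalizer = ToyProcess()


def get_hparams(learning_state):
    # Nested ordered dict: one entry per component that exposes hyperparameters.
    return OrderedDict(
        client_work=client_work.get_hparams(learning_state["client_work"]),
        finalizer=model_finalizer.get_hparams(learning_state["finalizer"]),
    )


def set_hparams(learning_state, hparams):
    # Delegate to each component; other entries (e.g. "aggregator") pass through.
    new_state = dict(learning_state)
    new_state["client_work"] = client_work.set_hparams(
        learning_state["client_work"], hparams["client_work"])
    new_state["finalizer"] = model_finalizer.set_hparams(
        learning_state["finalizer"], hparams["finalizer"])
    return new_state


state = {"client_work": {"learning_rate": 0.1},
         "finalizer": {"learning_rate": 1.0},
         "aggregator": {}}
hparams = get_hparams(state)
hparams["client_work"]["learning_rate"] = 0.05
state = set_hparams(state, hparams)
```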
Args | |
---|---|
initial_model_weights_fn | A tff.Computation that returns (unplaced) initial model weights. |
model_weights_distributor | A tff.learning.templates.DistributionProcess. |
client_work | A tff.learning.templates.ClientWorkProcess. |
model_update_aggregator | A tff.templates.AggregationProcess. |
model_finalizer | A tff.learning.templates.FinalizerProcess. |
Returns |
---|
A tff.learning.templates.LearningProcess. |
Raises | |
---|---|
ClientSequenceTypeError | If the first argument of the next computation of the resulting LearningProcess is not a structure of sequences placed at tff.CLIENTS. |
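The kind of check described in the Raises table can be sketched in plain Python. ToyType, check_client_data_type, and the string encoding of types (for example "sequence<float32>") are hypothetical. ClientSequenceTypeError is defined locally as a stand-in for the error named above. Real TFF type signatures are not inspected. The sketch rejects an argument type that is not placed at CLIENTS, or whose nested structure contains any leaf that is not a sequence.

```python
from dataclasses import dataclass
from typing import Any, Iterator


class ClientSequenceTypeError(TypeError):
    """Local stand-in for the error named in the Raises table."""


@dataclass(frozen=True)
class ToyType:
    placement: str  # e.g. "CLIENTS" or "SERVER" (hypothetical encoding)
    structure: Any  # nested dict/list/tuple whose leaves are type-name strings


def _leaves(structure: Any) -> Iterator[Any]:
    # Walk nested containers and yield every leaf type name.
    if isinstance(structure, dict):
        for value in structure.values():
            yield from _leaves(value)
    elif isinstance(structure, (list, tuple)):
        for value in structure:
            yield from _leaves(value)
    else:
        yield structure


def check_client_data_type(arg_type: ToyType) -> None:
    """Raise ClientSequenceTypeError unless arg_type is sequences placed at CLIENTS."""
    if arg_type.placement != "CLIENTS":
        raise ClientSequenceTypeError(
            f"expected placement CLIENTS, got {arg_type.placement}")
    bad = [leaf for leaf in _leaves(arg_type.structure)
           if not str(leaf).startswith("sequence<")]
    if bad:
        raise ClientSequenceTypeError(f"non-sequence leaves: {bad}")


# Passes: every leaf is a sequence and the placement is CLIENTS.
check_client_data_type(
    ToyType("CLIENTS", {"x": "sequence<float32>", "y": ["sequence<int32>"]}))
```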