tff.learning.algorithms.build_weighted_fed_avg | TensorFlow Federated
Builds a learning process that performs federated averaging.
tff.learning.algorithms.build_weighted_fed_avg(
model_fn: Union[Callable[[], tff.learning.models.VariableModel], tff.learning.models.FunctionalModel],
client_optimizer_fn: tff.learning.optimizers.Optimizer,
server_optimizer_fn: Optional[tff.learning.optimizers.Optimizer] = None,
*,
client_weighting: Optional[tff.learning.ClientWeighting] = tff.learning.ClientWeighting.NUM_EXAMPLES,
model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
model_aggregator: Optional[tff.aggregators.WeightedAggregationFactory] = None,
metrics_aggregator: Optional[tff.learning.metrics.MetricsAggregatorType] = None,
loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.LearningProcess
This function creates a tff.learning.templates.LearningProcess that performs federated averaging on client models. The iterative process has the following methods inherited from tff.learning.templates.LearningProcess:
- initialize: A tff.Computation with the functional type signature ( -> S@SERVER), where S is a tff.learning.templates.LearningAlgorithmState representing the initial state of the server.
- next: A tff.Computation with the functional type signature (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize, and {B*}@CLIENTS represents the client datasets. The output L contains the updated server state, as well as aggregated metrics at the server, including client training metrics and any other metrics from the distribution and aggregation processes.
- get_model_weights: A tff.Computation with type signature (S -> M), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and next, and M represents the type of the model weights used during training.
- set_model_weights: A tff.Computation with type signature (<S, M> -> S), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize, and M represents the type of the model weights used during training.
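To make the four type signatures concrete, the following plain-Python stand-in (no TFF needed) mimics how a driver loop calls these methods. It is an illustrative analogy, not TFF code: ToyState, ToyOutput, next_round (named to avoid shadowing Python's built-in next), and the one-parameter "model" whose local training simply fits each client's data mean are all hypothetical. Real TFF methods are tff.Computation objects with federated placements, and the real state and output types are richer.

```python
from typing import Dict, List, NamedTuple


class ToyState(NamedTuple):
    """Hypothetical stand-in for the server state S."""
    weight: float   # M: the model weights (a single scalar in this toy)
    round_num: int  # extra state that is not part of the model weights


class ToyOutput(NamedTuple):
    """Hypothetical stand-in for the output L of `next`: new state plus metrics."""
    state: ToyState
    metrics: Dict[str, float]


def initialize() -> ToyState:
    # ( -> S@SERVER): build the initial server state.
    return ToyState(weight=0.0, round_num=0)


def next_round(state: ToyState, client_datasets: List[List[float]]) -> ToyOutput:
    # (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>): one round over all clients.
    deltas, num_examples = [], []
    for data in client_datasets:
        trained = sum(data) / len(data)  # toy "local training": fit the client mean
        deltas.append(trained - state.weight)
        num_examples.append(len(data))
    total = sum(num_examples)
    # Weight each client's delta by its number of examples, then average.
    avg_delta = sum(d * n for d, n in zip(deltas, num_examples)) / total
    new_state = ToyState(weight=state.weight + avg_delta,
                         round_num=state.round_num + 1)
    return ToyOutput(state=new_state, metrics={"num_examples": float(total)})


def get_model_weights(state: ToyState) -> float:
    # (S -> M): read the model weights out of the state.
    return state.weight


def set_model_weights(state: ToyState, weight: float) -> ToyState:
    # (<S, M> -> S): replace the weights, keep the rest of the state.
    return state._replace(weight=weight)


state = initialize()
output = next_round(state, [[1.0, 2.0, 3.0], [10.0]])
state = output.state
print(get_model_weights(state), output.metrics)
```

In this toy, calling set_model_weights and then get_model_weights returns the weights you set, while the rest of the state (here the round counter) carries over unchanged.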
Each time the next method is called, the server model is communicated to each client using the provided model_distributor. For each client, local training is performed using client_optimizer_fn. Each client computes the difference between the client model after training and its initial model. These model deltas are then aggregated at the server using a weighted aggregation function, according to client_weighting. The aggregate model delta is applied at the server using the optimizer given by server_optimizer_fn.
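The round described above can be traced with a small plain-Python sketch (no TFF needed). It assumes a one-feature linear model y ≈ w * x + b trained with plain SGD on each client, weighting by number of examples (the NUM_EXAMPLES default), and an SGD server optimizer like the documented default (learning rate 1.0, no momentum). It also assumes the sign convention in which each client reports initial weights minus trained weights, so the server can apply the aggregate exactly like a gradient. The helpers local_sgd and fedavg_round are hypothetical names, not TFF functions.

```python
from typing import List, Tuple

Weights = List[float]                # [w, b] for the model y ≈ w * x + b
Dataset = List[Tuple[float, float]]  # (x, y) pairs held by one client


def local_sgd(weights: Weights, data: Dataset, lr: float, epochs: int) -> Weights:
    """Client-side training: plain SGD on squared error, one example at a time."""
    w, b = weights
    for _ in range(epochs):
        for x, y in data:
            err = (w * x + b) - y
            # Gradient of 0.5 * err**2 with respect to w and b.
            w -= lr * err * x
            b -= lr * err
    return [w, b]


def fedavg_round(server: Weights, clients: List[Dataset],
                 client_lr: float = 0.1, epochs: int = 1,
                 server_lr: float = 1.0) -> Weights:
    """One federated averaging round with example-count weighting and server SGD."""
    deltas, counts = [], []
    for data in clients:
        trained = local_sgd(server, data, client_lr, epochs)
        # Delta = initial - trained, so it can be applied like a gradient.
        deltas.append([s - t for s, t in zip(server, trained)])
        counts.append(len(data))
    total = sum(counts)
    # Weighted mean of the client deltas (weights = local example counts).
    avg_delta = [sum(d[i] * n for d, n in zip(deltas, counts)) / total
                 for i in range(len(server))]
    # Server optimizer: a plain SGD step on the averaged delta.
    return [s - server_lr * g for s, g in zip(server, avg_delta)]


clients = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]
print(fedavg_round([0.0, 0.0], clients))
```

With server_lr=1.0 and no momentum, the step reduces to replacing the server model with the example-weighted average of the trained client models; a smaller server_lr moves only part of the way toward that average.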
Args | Description |
---|---|
model_fn | A no-arg function that returns a tff.learning.models.VariableModel, or an instance of a tff.learning.models.FunctionalModel. When passing a callable, the callable must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error. |
client_optimizer_fn | A tff.learning.optimizers.Optimizer used for local training on each client. |
server_optimizer_fn | A tff.learning.optimizers.Optimizer used to apply the aggregated model delta at the server. By default, this uses tff.learning.optimizers.build_sgdm with a learning rate of 1.0. |
client_weighting | A member of tff.learning.ClientWeighting that specifies a built-in weighting method. By default, weighting by number of examples is used. |
model_distributor | An optional DistributionProcess that distributes the model weights on the server to the clients. If set to None, the distributor is constructed via tff.learning.templates.build_broadcast_process. |
model_aggregator | An optional tff.aggregators.WeightedAggregationFactory used to aggregate client updates on the server. If None, this is set to tff.aggregators.MeanFactory. |
metrics_aggregator | A function that takes in the metric finalizers (i.e., tff.learning.models.VariableModel.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.models.VariableModel.report_local_unfinalized_metrics()), and returns a tff.Computation for aggregating the unfinalized metrics. If None, this is set to tff.learning.metrics.sum_then_finalize. |
loop_implementation | Changes the implementation of the training loop generated. See tff.learning.LoopImplementation for more details. |
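As a rough, TFF-free illustration of three defaults in the table above: client_weighting decides each client's aggregation weight (NUM_EXAMPLES uses the local example count; the enum's other built-in member, UNIFORM, gives every client the same weight), a weighted mean aggregator such as the default MeanFactory computes sum(w_i * x_i) / sum(w_i) over client updates, and sum_then_finalize sums each client's unfinalized metric values before applying the finalizer once at the server. ToyClientWeighting, client_weight, weighted_mean, the local sum_then_finalize function, and the (correct, total) accuracy representation are hypothetical stand-ins, not the TFF implementations.

```python
from enum import Enum
from typing import Dict, List


class ToyClientWeighting(Enum):
    """Hypothetical mirror of the two built-in tff.learning.ClientWeighting members."""
    UNIFORM = "uniform"
    NUM_EXAMPLES = "num_examples"


def client_weight(mode: ToyClientWeighting, num_examples: int) -> float:
    # UNIFORM: every client counts equally; NUM_EXAMPLES: weight by local data size.
    return 1.0 if mode is ToyClientWeighting.UNIFORM else float(num_examples)


def weighted_mean(updates: List[float], weights: List[float]) -> float:
    # sum(w_i * x_i) / sum(w_i)
    return sum(u * w for u, w in zip(updates, weights)) / sum(weights)


def sum_then_finalize(client_metrics: List[Dict[str, float]]) -> float:
    # Sum the unfinalized counters across clients ...
    correct = sum(m["correct"] for m in client_metrics)
    total = sum(m["total"] for m in client_metrics)
    # ... then apply the accuracy "finalizer" once, at the server.
    return correct / total


updates = [1.0, 4.0]  # one scalar update per client
counts = [1, 3]       # local example counts
for mode in ToyClientWeighting:
    weights = [client_weight(mode, n) for n in counts]
    print(mode.name, weighted_mean(updates, weights))

client_metrics = [{"correct": 1.0, "total": 1.0}, {"correct": 1.0, "total": 3.0}]
print("sum_then_finalize:", sum_then_finalize(client_metrics))
naive = sum(m["correct"] / m["total"] for m in client_metrics) / len(client_metrics)
print("mean of per-client accuracies:", naive)
```

The last two prints show why summing before finalizing matters: averaging per-client accuracies (1.0 and 1/3) gives about 0.667, while summing the counters first gives 2/4 = 0.5, the accuracy over all examples.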
Returns |
---|
A tff.learning.templates.LearningProcess. |
Raises | Description |
---|---|
TypeError | If arguments are not the documented types. |