tff.learning.algorithms.build_fed_sgd | TensorFlow Federated
tff.learning.algorithms.build_fed_sgd
Builds a learning process that performs federated SGD.
tff.learning.algorithms.build_fed_sgd(
model_fn: Union[Callable[[], tff.learning.models.VariableModel], tff.learning.models.FunctionalModel],
server_optimizer_fn: tff.learning.optimizers.Optimizer = DEFAULT_SERVER_OPTIMIZER_FN,
model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
model_aggregator: Optional[tff.aggregators.WeightedAggregationFactory] = None,
metrics_aggregator: Optional[tff.learning.metrics.MetricsAggregatorType] = None,
loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.LearningProcess
This function creates a tff.learning.templates.LearningProcess that performs federated SGD on client models. The learning process has the following methods inherited from tff.learning.templates.LearningProcess:
- initialize: A tff.Computation with type signature ( -> S@SERVER), where S is a tff.learning.templates.LearningAlgorithmState representing the initial state of the server.
- next: A tff.Computation with type signature (<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>), where S is a LearningAlgorithmState whose type matches that of the output of initialize, and {B*}@CLIENTS represents the client datasets, where B is the type of a single batch. This computation returns the updated server state S together with T, the metrics from client training and any other metrics from the broadcast and aggregation processes.
- get_model_weights: A tff.Computation with type signature (S -> M), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and next, and M represents the type of the model weights used during training.
- set_model_weights: A tff.Computation with type signature (<S, M> -> S), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize, and M represents the type of the model weights used during training.
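To make the four type signatures concrete, here is a toy plain-Python analogue of the contract. It is not TFF code: ToyState, ToyOutput and ToyLearningProcess are hypothetical names, ordinary Python values stand in for federated types, and next is a placeholder that only advances a round counter and reports an example count instead of training.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, NamedTuple, Sequence, Tuple

# Hypothetical stand-ins for the TFF types (not TFF APIs):
#   S -> ToyState, M -> Weights, B -> Batch, T -> the metrics dict.
Weights = List[float]
Batch = Sequence[Tuple[float, float]]  # (feature, label) pairs


@dataclass(frozen=True)
class ToyState:
    """Server state S: the global model weights plus a round counter."""
    weights: Tuple[float, ...]
    round_num: int


class ToyOutput(NamedTuple):
    """Result of next(): the updated state S and the metrics T."""
    state: ToyState
    metrics: Dict[str, float]


class ToyLearningProcess:
    def initialize(self) -> ToyState:
        # ( -> S@SERVER): build the initial server state.
        return ToyState(weights=(0.0, 0.0), round_num=0)

    def next(self, state: ToyState, client_data: List[List[Batch]]) -> ToyOutput:
        # (<S@SERVER, {B*}@CLIENTS> -> <S@SERVER, T@SERVER>).
        # Placeholder round: weights are left unchanged; only the round
        # counter advances and a simple metric is reported.
        num_examples = sum(len(batch) for dataset in client_data for batch in dataset)
        new_state = replace(state, round_num=state.round_num + 1)
        return ToyOutput(new_state, {"num_examples": float(num_examples)})

    def get_model_weights(self, state: ToyState) -> Weights:
        # (S -> M): read the model weights out of the state.
        return list(state.weights)

    def set_model_weights(self, state: ToyState, weights: Weights) -> ToyState:
        # (<S, M> -> S): return a new state carrying the given weights.
        return replace(state, weights=tuple(weights))


process = ToyLearningProcess()
state = process.initialize()
clients = [[[(1.0, 2.0), (2.0, 4.0)]], [[(3.0, 6.0)], [(4.0, 8.0)]]]
output = process.next(state, clients)
state = process.set_model_weights(output.state, [0.5, 1.0])
print(output.metrics, process.get_model_weights(state))
```

The state is treated as immutable here, so set_model_weights returns a new state rather than mutating the old one, mirroring the (<S, M> -> S) signature.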
Each time next is called, the server model is broadcast to each client using a distributor. Each client computes the gradients for each batch in its local dataset (without updating its model), sums them, and averages them over its number of examples. These average gradients are then aggregated at the server by the model aggregator and applied to the server model by the server optimizer.
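The round described above can be sketched in plain Python under several assumptions that are not part of the TFF API: a hypothetical 1-D linear model with weights [slope, intercept], a per-example loss 0.5 * (w*x + b - y)**2, datasets given as lists of (x, y) batches, an example-weighted mean standing in for the model aggregator, and plain SGD standing in for the server optimizer. The helper names are illustrative only.

```python
from typing import List, Sequence, Tuple

Weights = List[float]                  # [slope, intercept] of a hypothetical 1-D linear model
Batch = Sequence[Tuple[float, float]]  # (x, y) pairs


def batch_grad_sum(weights: Weights, batch: Batch) -> Weights:
    """Sum of per-example gradients of 0.5 * (w*x + b - y)**2 over one batch."""
    w, b = weights
    gw = gb = 0.0
    for x, y in batch:
        err = w * x + b - y
        gw += err * x
        gb += err
    return [gw, gb]


def client_update(weights: Weights, dataset: List[Batch]) -> Tuple[Weights, int]:
    """Sum gradients over all batches WITHOUT updating the model, then average."""
    total = [0.0, 0.0]
    n = 0
    for batch in dataset:
        g = batch_grad_sum(weights, batch)  # the model weights stay fixed
        total = [t + gi for t, gi in zip(total, g)]
        n += len(batch)
    return [t / n for t in total], n  # assumes a non-empty dataset


def fed_sgd_round(weights: Weights, client_data: List[List[Batch]], lr: float) -> Weights:
    """One round: broadcast, client average gradients, weighted mean, server SGD step."""
    results = [client_update(weights, ds) for ds in client_data]  # every client sees the same weights
    total_n = sum(n for _, n in results)
    # Example-weighted mean of the client average gradients.
    avg = [sum(g[i] * n for g, n in results) / total_n for i in range(len(weights))]
    # Assumed server optimizer: plain SGD with learning rate lr.
    return [wi - lr * gi for wi, gi in zip(weights, avg)]


clients = [
    [[(1.0, 2.0), (2.0, 4.0)]],    # client A: one batch of two examples
    [[(3.0, 6.0)], [(4.0, 8.0)]],  # client B: two batches of one example
]
new_weights = fed_sgd_round([0.0, 0.0], clients, lr=0.1)
print(new_weights)  # → [1.5, 0.5]
```

Because each client's average is weighted by its example count, this particular round equals a single full-batch gradient descent step on the pooled data (pooled average gradient [-15, -5]). In TFF itself the aggregation and the server step are configurable through model_aggregator and server_optimizer_fn.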
This implements the original FedSGD algorithm in McMahan et al., 2017.
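As a worked form of that update, assuming plain gradient descent with learning rate eta in place of the configurable server optimizer: K participating clients, client k holding n_k examples indexed by P_k, per-example loss l, and current server model w_t (symbols introduced here for illustration, following the paper's notation):

```latex
g_k = \nabla F_k(w_t) = \frac{1}{n_k} \sum_{i \in \mathcal{P}_k} \nabla \ell(w_t;\, x_i, y_i),
\qquad
w_{t+1} = w_t - \eta \sum_{k=1}^{K} \frac{n_k}{n}\, g_k,
\qquad
n = \sum_{k=1}^{K} n_k
```

Each g_k is computed at the same broadcast model w_t, which is what distinguishes FedSGD from methods that take several local steps before communicating.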
Args | Description |
---|---|
model_fn | A no-arg function that returns a tff.learning.models.VariableModel, or an instance of a tff.learning.models.FunctionalModel. When passing a callable, the callable must not capture and use TensorFlow tensors or variables. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error. |
server_optimizer_fn | A tff.learning.optimizers.Optimizer used to apply client updates to the server model. |
model_distributor | An optional DistributionProcess that distributes the model weights on the server to the clients. If set to None, the distributor is constructed via distributors.build_broadcast_process. |
model_aggregator | An optional tff.aggregators.WeightedAggregationFactory used to aggregate client updates on the server. If None, this is set to tff.aggregators.MeanFactory. |
metrics_aggregator | An optional function that takes in the metric finalizers (i.e., tff.learning.models.VariableModel.metric_finalizers()) and a tff.types.StructWithPythonType of the unfinalized metrics (i.e., the TFF type of tff.learning.models.VariableModel.report_local_unfinalized_metrics()), and returns a tff.Computation for aggregating the unfinalized metrics. |
loop_implementation | Changes the implementation of the generated training loop. See tff.learning.LoopImplementation for more details. |
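The loop_implementation argument controls how each client's loop over batches is generated; the default, DATASET_REDUCE, suggests a reduce (fold) over the dataset. The following is only a loose plain-Python analogy, not TFF code: it contrasts a fold-style loop with an explicit iterator-style loop for a hypothetical scalar model y ≈ w * x, and all helper names are illustrative. Both styles are meant to compute the same per-client gradient sum; the choice concerns how the loop is built, not what it computes.

```python
import functools
from typing import List, Sequence, Tuple

Batch = Sequence[Tuple[float, float]]  # (x, y) pairs for a hypothetical scalar model y ≈ w * x


def batch_grad_sum(w: float, batch: Batch) -> float:
    """Sum over one batch of the gradient of 0.5 * (w * x - y) ** 2 with respect to w."""
    return sum((w * x - y) * x for x, y in batch)


def grad_sum_reduce(w: float, dataset: List[Batch]) -> float:
    # Fold style: thread an accumulator through the batches (analogy to DATASET_REDUCE).
    return functools.reduce(lambda acc, batch: acc + batch_grad_sum(w, batch), dataset, 0.0)


def grad_sum_iterator(w: float, dataset: List[Batch]) -> float:
    # Iterator style: an explicit Python loop over the batches.
    total = 0.0
    for batch in dataset:
        total += batch_grad_sum(w, batch)
    return total


dataset = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]
print(grad_sum_reduce(0.5, dataset), grad_sum_iterator(0.5, dataset))  # → -21.0 -21.0
```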
Returns |
---|
A tff.learning.templates.LearningProcess. |