tff.learning.algorithms.build_fed_eval | TensorFlow Federated

Builds a learning process that performs federated evaluation.

tff.learning.algorithms.build_fed_eval(
    model_fn: Union[Callable[[], tff.learning.models.VariableModel], tff.learning.models.FunctionalModel],
    model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
    metrics_aggregation_process: Optional[tff.templates.AggregationProcess] = None,
    loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.LearningProcess

Used in the tutorials
Federated Learning for Image Classification

This function creates a tff.learning.templates.LearningProcess that performs federated evaluation on clients. The learning process exposes the initialize, next, get_model_weights, and set_model_weights methods from tff.learning.templates.LearningProcess.

Each time next is called, the server model is broadcast to each client by the distributor. Each client evaluates the model and reports its local unfinalized metrics, which the metrics aggregator then aggregates and finalizes at the server. Both current-round and total-rounds metrics are produced. The server model is not updated during evaluation.
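The round described above can be modeled in a few lines of plain Python. This is a sketch of the semantics only, not the TFF API: the threshold classifier, its single-float weights, the (x, label) client data, and the names initialize_fn, next_fn, client_eval, and finalize are all hypothetical. It shows the server weights being broadcast, each client reporting unfinalized sums rather than a finished accuracy, the server summing and then finalizing both the current-round and the cumulative metrics, and the weights coming back unchanged.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

# Plain-Python model of one federated evaluation round. NOT the TFF API:
# the "model" is a hypothetical 1-D threshold classifier whose weights are a
# single float, and each client's data is a list of (x, label) pairs.

ClientData = List[Tuple[float, int]]


@dataclass
class EvalState:
    model_weights: float       # server model; next_fn never modifies it
    total: Dict[str, float]    # unfinalized metrics summed over all rounds


def client_eval(weights: float, data: ClientData) -> Dict[str, float]:
    # Local evaluation: report unfinalized metrics (raw sums), not accuracy.
    correct = sum(1 for x, y in data if int(x > weights) == y)
    return {"num_correct": float(correct), "num_examples": float(len(data))}


def finalize(m: Dict[str, float]) -> float:
    # Turn summed unfinalized metrics into the final accuracy value.
    return m["num_correct"] / m["num_examples"]


def initialize_fn(weights: float) -> EvalState:
    return EvalState(weights, {"num_correct": 0.0, "num_examples": 0.0})


def next_fn(state: EvalState, clients: List[ClientData]):
    # 1. Broadcast the same server weights to every client.
    broadcast = [state.model_weights] * len(clients)
    # 2. Each client evaluates locally and reports unfinalized metrics.
    local = [client_eval(w, d) for w, d in zip(broadcast, clients)]
    # 3. Sum unfinalized metrics at the server, for this round and cumulatively.
    round_sum = {k: sum(m[k] for m in local) for k in state.total}
    total = {k: state.total[k] + round_sum[k] for k in state.total}
    # 4. Finalize both views; the model weights are carried over unchanged.
    metrics = {"current_round_accuracy": finalize(round_sum),
               "total_rounds_accuracy": finalize(total)}
    return EvalState(state.model_weights, total), metrics


state = initialize_fn(0.5)
state, round1 = next_fn(state, [[(0.9, 1), (0.1, 0)], [(0.7, 0)]])
state, round2 = next_fn(state, [[(0.2, 1)], [(0.8, 1)]])
```

A typical TFF flow is roughly analogous: create the state with initialize(), copy trained weights into it with set_model_weights(), then call next(state, client_data) and read the metrics from the output. Unlike initialize_fn in this sketch, the real initialize takes no arguments.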

Args
model_fn A no-arg function that returns a tff.learning.models.VariableModel, or an instance of a tff.learning.models.FunctionalModel. When passing a callable, the callable must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error.
model_distributor An optional tff.learning.templates.DistributionProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENTS) and have empty state. If None, the server model is broadcast to the clients using the default tff.federated_broadcast.
metrics_aggregation_process An optional tff.templates.AggregationProcess that aggregates the clients' local unfinalized metrics to the server and finalizes them there. The tff.templates.AggregationProcess accumulates unfinalized metrics across rounds in its state, and produces a tuple of current-round metrics and total-rounds metrics in its result. If None, the tff.templates.AggregationProcess created by the SumThenFinalizeFactory with the metric finalizers defined in the model is used.
loop_implementation Changes the implementation of the generated loop that iterates over each client's dataset. See tff.learning.LoopImplementation for more details.
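The model_fn requirement is about object identity: the builder may call model_fn more than once, and each call must produce a new model with its own freshly created state. The plain-Python sketch below illustrates the difference using a hypothetical TinyModel class as a stand-in for a VariableModel; it does not use TFF. In real code, the body of model_fn typically constructs a new Keras model and wraps it, for example with tff.learning.models.from_keras_model, rather than referring to a model built outside the function.

```python
class TinyModel:
    """Hypothetical stand-in for a tff.learning.models.VariableModel."""

    def __init__(self) -> None:
        # All state is created inside the constructor, so every call to
        # model_fn yields an independent set of "variables".
        self.weights = [0.0, 0.0]


def model_fn() -> TinyModel:
    # Correct pattern: construct a brand-new model on every invocation.
    return TinyModel()


_prebuilt = TinyModel()


def bad_model_fn() -> TinyModel:
    # Anti-pattern: hands back the same pre-constructed object on each call.
    return _prebuilt


a, b = model_fn(), model_fn()
a.weights[0] = 1.0  # mutating one model leaves the other untouched
```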
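The contract of metrics_aggregation_process (state that accumulates unfinalized metrics across rounds, and a result holding current-round and total-rounds metrics) can be sketched in plain Python. The class SumThenFinalizeSketch, its method names, and the [sum, count] encoding of a mean metric are hypothetical illustrations of the sum-then-finalize idea, not the real SumThenFinalizeFactory or the tff.templates.AggregationProcess API.

```python
from typing import Callable, Dict, List, Tuple

# Each metric's unfinalized value is a list of floats, e.g. [sum, count].
Unfinalized = Dict[str, List[float]]
Finalizers = Dict[str, Callable[[List[float]], float]]


def _add(a: Unfinalized, b: Unfinalized) -> Unfinalized:
    # Element-wise sum of every metric's unfinalized values.
    return {k: [x + y for x, y in zip(a[k], b[k])] for k in a}


class SumThenFinalizeSketch:
    """Hypothetical plain-Python analogue, not the TFF factory itself."""

    def __init__(self, finalizers: Finalizers, zeros: Unfinalized) -> None:
        self.finalizers = finalizers
        self.zeros = zeros

    def _zeros(self) -> Unfinalized:
        return {k: list(v) for k, v in self.zeros.items()}

    def initialize(self) -> Unfinalized:
        # State: unfinalized metrics accumulated over all rounds so far.
        return self._zeros()

    def next(self, state: Unfinalized, client_metrics: List[Unfinalized]
             ) -> Tuple[Unfinalized, Tuple[Dict[str, float], Dict[str, float]]]:
        # Sum unfinalized metrics across this round's clients.
        round_sum = self._zeros()
        for m in client_metrics:
            round_sum = _add(round_sum, m)
        # Fold this round into the cross-round state.
        new_state = _add(state, round_sum)
        # Finalize only after summing, for both the round and the total.
        current = {k: f(round_sum[k]) for k, f in self.finalizers.items()}
        total = {k: f(new_state[k]) for k, f in self.finalizers.items()}
        return new_state, (current, total)


# A mean metric stored as [sum, count] and finalized as sum / count.
agg = SumThenFinalizeSketch({"mean_loss": lambda v: v[0] / v[1]},
                            {"mean_loss": [0.0, 0.0]})
state = agg.initialize()
state, (cur1, tot1) = agg.next(state, [{"mean_loss": [2.0, 4.0]},
                                       {"mean_loss": [1.0, 2.0]}])
state, (cur2, tot2) = agg.next(state, [{"mean_loss": [6.0, 4.0]}])
```

Summing the raw [sum, count] pairs before dividing yields an example-weighted mean across all clients. Finalizing on each client first and then averaging the per-client means would instead weight every client equally, which is generally a different number.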
Returns
A tff.learning.templates.LearningProcess that performs federated evaluation on clients and returns updated state and metrics.
Raises
TypeError If any argument has an unexpected type.

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.

Last updated 2024-09-20 UTC.