tff.learning.algorithms.build_fed_eval | TensorFlow Federated
Builds a learning process that performs federated evaluation.
tff.learning.algorithms.build_fed_eval(
    model_fn: Union[Callable[[], tff.learning.models.VariableModel], tff.learning.models.FunctionalModel],
    model_distributor: Optional[tff.learning.templates.DistributionProcess] = None,
    metrics_aggregation_process: Optional[tff.templates.AggregationProcess] = None,
    loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.LearningProcess
Used in the tutorials: Federated Learning for Image Classification
This function creates a tff.learning.templates.LearningProcess that performs federated evaluation on clients. The learning process has the following methods inherited from tff.learning.templates.LearningProcess:
initialize: A tff.Computation with type signature ( -> S@SERVER), where S is a tff.learning.templates.LearningAlgorithmState representing the initial state of the server.
next: A tff.Computation with type signature (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>), where S is a LearningAlgorithmState whose type matches that of the output of initialize, and {B*}@CLIENTS represents the client datasets, where B is the type of a single batch. The output L contains the updated server state, as well as aggregated metrics at the server, including client evaluation metrics and any other metrics from the distribution and aggregation processes.
get_model_weights: A tff.Computation with type signature (S -> M), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize and next, and M represents the type of the model weights used during evaluation.
set_model_weights: A tff.Computation with type signature (<S, M> -> S), where S is a tff.learning.templates.LearningAlgorithmState whose type matches the output of initialize, and M represents the type of the model weights used during evaluation.
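To make the state threading through these four methods concrete, here is a framework-free sketch that mirrors their shapes in plain Python. Everything in it (ToyFedEvalProcess, ToyState, the toy linear model, the "mse" metric) is hypothetical and not part of TFF. In real TFF these methods are tff.Computation objects operating on federated values, and the state also carries the metrics aggregator's state, which this toy omits.

```python
from dataclasses import dataclass, replace
from typing import Dict, List, Sequence, Tuple

Weights = Tuple[float, float]        # (slope, intercept) of a toy linear model
Batch = List[Tuple[float, float]]    # one batch B of (x, y) examples


@dataclass(frozen=True)
class ToyState:
    """Hypothetical stand-in for LearningAlgorithmState (weights only)."""
    weights: Weights


class ToyFedEvalProcess:
    """Hypothetical, framework-free mirror of the four methods; not TFF code."""

    def initialize(self) -> ToyState:
        # ( -> S@SERVER): initial server state.
        return ToyState(weights=(0.0, 0.0))

    def next(self, state: ToyState,
             client_data: Sequence[Sequence[Batch]]) -> Tuple[ToyState, Dict[str, float]]:
        # (<S@SERVER, {B*}@CLIENTS> -> <L@SERVER>): evaluate only, never train.
        slope, intercept = state.weights
        loss_sum, num_examples = 0.0, 0
        for dataset in client_data:      # one dataset per client
            for batch in dataset:        # each dataset is a sequence of batches
                for x, y in batch:
                    loss_sum += (slope * x + intercept - y) ** 2
                    num_examples += 1
        metrics = {
            "mse": loss_sum / num_examples if num_examples else 0.0,
            "num_examples": float(num_examples),
        }
        # Weights are unchanged; a real state would also carry updated aggregator state.
        return state, metrics

    def get_model_weights(self, state: ToyState) -> Weights:
        # (S -> M)
        return state.weights

    def set_model_weights(self, state: ToyState, weights: Weights) -> ToyState:
        # (<S, M> -> S): returns a new state; the input state is not mutated.
        return replace(state, weights=weights)


process = ToyFedEvalProcess()
state = process.initialize()
state = process.set_model_weights(state, (2.0, 1.0))  # e.g. weights from a training run
clients = [[[(0.0, 1.0), (1.0, 3.0)]], [[(2.0, 6.0)]]]  # two clients, one batch each
state, metrics = process.next(state, clients)
print(process.get_model_weights(state), metrics["num_examples"])  # (2.0, 1.0) 3.0
```

Because set_model_weights returns a new state instead of modifying its input, matching the (<S, M> -> S) signature, its result has to be reassigned before it is passed to next.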
Each time next is called, the server model is broadcast to each client using the model distributor. Each client evaluates the model and reports local unfinalized metrics, which are then aggregated and finalized at the server by the metrics aggregation process. Each call produces both the current round's metrics and the metrics accumulated over all rounds so far. The server model is not updated during evaluation.
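The round described above can be sketched in plain Python. This is a simplified, hypothetical model of sum-then-finalize aggregation, assuming a toy one-parameter model y_hat = w * x and a mean-squared-error metric; it is not TFF code, and the function names are illustrative only.

```python
from typing import Dict, List, Tuple

Metrics = Dict[str, float]
Examples = List[Tuple[float, float]]  # (x, y) pairs held by one client


def client_unfinalized_metrics(w: float, examples: Examples) -> Metrics:
    """Client step: report sums, not a mean, so they can be added across clients."""
    loss_sum = sum((w * x - y) ** 2 for x, y in examples)
    return {"loss_sum": loss_sum, "num_examples": float(len(examples))}


def sum_metrics(parts: List[Metrics]) -> Metrics:
    """Element-wise sum of unfinalized metric dicts (the 'sum' in sum-then-finalize)."""
    total: Metrics = {}
    for part in parts:
        for key, value in part.items():
            total[key] = total.get(key, 0.0) + value
    return total


def finalize(unfinalized: Metrics) -> Metrics:
    """Server step: turn the summed values into a mean loss."""
    n = unfinalized.get("num_examples", 0.0)
    mean = unfinalized["loss_sum"] / n if n else 0.0
    return {"mean_loss": mean, "num_examples": n}


def evaluation_round(w: float, client_datasets: List[Examples],
                     totals: Metrics) -> Tuple[Metrics, Tuple[Metrics, Metrics]]:
    # 1. Broadcast: every client evaluates the same server weights w.
    per_client = [client_unfinalized_metrics(w, ds) for ds in client_datasets]
    # 2. Aggregate this round's unfinalized metrics at the server.
    current = sum_metrics(per_client)
    # 3. Accumulate into the running totals kept in the aggregator state.
    new_totals = sum_metrics([totals, current])
    # 4. Finalize both; w itself is never modified.
    return new_totals, (finalize(current), finalize(new_totals))


totals: Metrics = {}
totals, (current, overall) = evaluation_round(2.0, [[(1.0, 2.0), (2.0, 5.0)], [(3.0, 6.0)]], totals)
totals, (current, overall) = evaluation_round(2.0, [[(1.0, 4.0)]], totals)
print(current)  # {'mean_loss': 4.0, 'num_examples': 1.0}
print(overall)  # {'mean_loss': 1.25, 'num_examples': 4.0}
```

Keeping metrics unfinalized (a loss sum and an example count) until after aggregation is what makes the result an example-weighted mean; averaging each client's already-finalized mean instead would give a client with one example the same weight as a client with thousands.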
Argument | Description |
---|---|
model_fn | A no-arg function that returns a tff.learning.models.VariableModel, or an instance of tff.learning.models.FunctionalModel. When passing a callable, the callable must not capture and use existing TensorFlow tensors or variables. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call results in an error. |
model_distributor | An optional tff.learning.templates.DistributionProcess that broadcasts the model weights on the server to the clients. It must support the signature (input_values@SERVER -> output_values@CLIENTS) and have empty state. If None, the server model is broadcast to the clients using the default tff.federated_broadcast. |
metrics_aggregation_process | An optional tff.templates.AggregationProcess that aggregates the local unfinalized metrics from the clients to the server and finalizes them at the server. The process accumulates unfinalized metrics across rounds in its state, and its result is a tuple of the current round's metrics and the metrics accumulated over all rounds. If None, the tff.templates.AggregationProcess created by the SumThenFinalizeFactory, using the metric finalizers defined in the model, is used. |
loop_implementation | Changes the implementation of the generated loop over client data. See tff.learning.LoopImplementation for more details. |
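The model_fn rule can be illustrated without TFF. The sketch below contrasts a factory that builds a new object on every call with one that returns a captured, pre-built instance. ToyModel, good_model_fn, bad_model_fn, and builds_fresh_model are hypothetical names, and the check is only an illustration of the contract, not something TFF exposes. TFF may invoke model_fn more than once (for example, while tracing computations), so each call needs its own freshly constructed model.

```python
from typing import Callable


class ToyModel:
    """Hypothetical stand-in for a model that owns its own variables."""

    def __init__(self) -> None:
        self.weights = [0.0, 0.0]


def good_model_fn() -> ToyModel:
    # Correct pattern: build everything from scratch on every call.
    return ToyModel()


_prebuilt = ToyModel()


def bad_model_fn() -> ToyModel:
    # Incorrect pattern: hands back the same captured, pre-built object each time.
    return _prebuilt


def builds_fresh_model(model_fn: Callable[[], ToyModel]) -> bool:
    """Hypothetical check: two calls must return two distinct model objects."""
    first, second = model_fn(), model_fn()
    return first is not second


print(builds_fresh_model(good_model_fn))  # True
print(builds_fresh_model(bad_model_fn))   # False
```

In TFF code the same principle means creating the Keras model (and wrapping it) inside model_fn, rather than building it once outside the function and returning that object.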
Returns |
---|
A tff.learning.templates.LearningProcess that performs federated evaluation on clients and returns the updated state and metrics. |
Raises | Condition |
---|---|
TypeError | If any argument has an unexpected type. |
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-09-20 UTC.