tff.learning.algorithms.build_personalization_eval_computation
Builds the TFF computation for evaluating personalization strategies.
```python
tff.learning.algorithms.build_personalization_eval_computation(
    model_fn: Callable[[], tff.learning.models.VariableModel],
    personalize_fn_dict: Mapping[str, Callable[[], tff.learning.models.VariableModel]],
    baseline_evaluate_fn: Callable[[tff.learning.models.VariableModel, tf.data.Dataset], _MetricsType],
    max_num_clients: int = 100
) -> tff.Computation
```
The returned TFF computation broadcasts model weights from tff.SERVER to tff.CLIENTS. Each client computes baseline metrics with baseline_evaluate_fn and evaluates the personalization strategies given in personalize_fn_dict. Evaluation metrics from at most max_num_clients participating clients are collected at the server.
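As a rough mental model only, the round described above can be mimicked in plain Python: the server weights are copied to every client, each client runs its own evaluation, and the server keeps the results of at most max_num_clients clients, sampled without replacement. This is not TFF code. run_round, toy_evaluate_client, and the dict-based weights are hypothetical stand-ins, and real TFF expresses placement with tff.SERVER and tff.CLIENTS rather than Python lists.

```python
import copy
import random
from collections import OrderedDict


def run_round(server_weights, client_inputs, evaluate_client,
              max_num_clients=100, seed=0):
    """Hypothetical plain-Python model of one evaluation round."""
    if max_num_clients <= 0:
        raise ValueError("max_num_clients must be positive.")
    per_client_metrics = []
    for client_input in client_inputs:
        # "Broadcast": each client gets its own copy, so local training on
        # one client cannot change the weights another client starts from.
        local_weights = copy.deepcopy(server_weights)
        per_client_metrics.append(evaluate_client(local_weights, client_input))
    # "Collect": keep metrics from at most max_num_clients clients, sampled
    # without replacement; if fewer clients took part, keep them all.
    rng = random.Random(seed)
    k = min(max_num_clients, len(per_client_metrics))
    return rng.sample(per_client_metrics, k)


def toy_evaluate_client(weights, data):
    # Toy evaluation: modify the local copy, then report simple metrics.
    weights["bias"] += 1.0
    return OrderedDict(bias_after=weights["bias"], num_examples=len(data))


server_weights = {"bias": 0.0}
results = run_round(server_weights, [[1.0], [2.0, 3.0], [4.0]],
                    toy_evaluate_client, max_num_clients=2)
print(len(results))            # 2
print(server_weights["bias"])  # 0.0
```

In this simplified model every client computes its metrics and only the reporting is capped; it is meant to show the data flow, not how TFF schedules the work.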
Args | |
---|---|
model_fn | A no-arg function that returns a tff.learning.models.VariableModel. This method must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error. |
personalize_fn_dict | An OrderedDict that maps a string (representing a strategy name) to a no-argument function that returns a tf.function. Each tf.function represents a personalization strategy: it accepts a tff.learning.models.VariableModel (with weights already initialized to the given model weights when users invoke the returned TFF computation), an unbatched tf.data.Dataset for training, and an unbatched tf.data.Dataset for testing; trains a personalized model; and returns the evaluation metrics. The evaluation metrics are represented as an OrderedDict (or a nested OrderedDict) of string metric names to scalar tf.Tensors. |
baseline_evaluate_fn | A tf.function that accepts a tff.learning.models.VariableModel (with weights already initialized to the provided model weights when users invoke the returned TFF computation) and an unbatched tf.data.Dataset, evaluates the model on the dataset, and returns the evaluation metrics. The evaluation metrics are represented as an OrderedDict (or a nested OrderedDict) of string metric names to scalar tf.Tensors. This function is only used to compute the baseline metrics of the initial model. |
max_num_clients | A positive int specifying the maximum number of clients whose metrics are collected in a round (default is 100). The clients are sampled without replacement. For each sampled client, all of that client's personalization metrics are collected. If fewer clients than this value participate in a round, metrics from all participating clients are collected. |
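The contracts of the three callable arguments can be sketched in plain Python under heavy simplification: ToyModel is a hypothetical one-parameter model that always predicts its bias, plain lists replace tf.data.Dataset, and Python floats replace scalar tf.Tensors. In real code, model_fn would build a fresh model on each call (for example, a new Keras model wrapped with tff.learning.models.from_keras_model), and the strategy and baseline functions would be tf.functions. build_finetune_fn and the metric names are hypothetical.

```python
from collections import OrderedDict


class ToyModel:
    """Hypothetical stand-in for a tff.learning.models.VariableModel."""

    def __init__(self, bias=0.0):
        self.bias = bias  # fresh state on every construction


def model_fn():
    # Builds a brand-new model on every call. Returning one pre-built,
    # captured object instead is the pattern the docs say causes an error.
    return ToyModel()


def mean_squared_error(model, data):
    return sum((x - model.bias) ** 2 for x in data) / len(data)


def build_finetune_fn(num_epochs, learning_rate):
    """Returns a no-arg function that returns a personalization strategy."""

    def make_strategy():
        def personalize(model, train_data, test_data):
            metrics = OrderedDict()
            for epoch in range(num_epochs):
                # Per-example gradient descent on squared error.
                for x in train_data:
                    model.bias -= learning_rate * 2.0 * (model.bias - x)
                # Nested OrderedDict: test metrics after each epoch.
                metrics[f"epoch_{epoch + 1}"] = OrderedDict(
                    mse=mean_squared_error(model, test_data))
            metrics["num_train_examples"] = len(train_data)
            return metrics

        return personalize

    return make_strategy


def baseline_evaluate_fn(model, dataset):
    # No training: evaluates the model with the weights it was given.
    return OrderedDict(mse=mean_squared_error(model, dataset),
                       num_examples=len(dataset))


personalize_fn_dict = OrderedDict(
    sgd_1_epoch=build_finetune_fn(num_epochs=1, learning_rate=0.1),
    sgd_3_epochs=build_finetune_fn(num_epochs=3, learning_rate=0.1),
)

print(model_fn() is model_fn())  # False: a new object on every call
train_data, test_data = [1.0, 1.0], [1.0]
print(baseline_evaluate_fn(model_fn(), test_data)["mse"])  # 1.0
strategy = personalize_fn_dict["sgd_1_epoch"]()
result = strategy(model_fn(), train_data, test_data)
print(round(result["epoch_1"]["mse"], 4))  # 0.4096
```

The extra level of indirection (name, then factory, then strategy) matches the description of personalize_fn_dict: each value is a no-argument function, and calling it yields the function that actually trains and evaluates.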
Returns |
---|
A federated tff.Computation with the functional type signature (<model_weights@SERVER, input@CLIENTS> -> personalization_metrics@SERVER), where: model_weights is a tff.learning.models.ModelWeights; each client's input is an OrderedDict with two required keys, train_data and test_data, each mapped to an unbatched tf.data.Dataset; and personalization_metrics is an OrderedDict that maps the key 'baseline_metrics' to the evaluation metrics of the initial model (computed by baseline_evaluate_fn) and maps each key (strategy name) in personalize_fn_dict to the evaluation metrics of the corresponding personalization strategy. |
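The per-client part of this signature can be modeled in plain Python to show how the documented input keys (train_data, test_data) relate to the output keys ('baseline_metrics' plus one entry per strategy name). evaluate_one_client, toy_baseline, and make_count_strategy are hypothetical; plain lists stand in for tf.data.Dataset, and passing test_data to the baseline function is an assumption, since the description above does not state which split the baseline uses.

```python
import copy
from collections import OrderedDict


def evaluate_one_client(model_weights, client_input, baseline_evaluate_fn,
                        personalize_fn_dict):
    """Hypothetical model of one client's step; returns its metrics."""
    metrics = OrderedDict()
    # Assumption: the baseline is computed on the test split.
    metrics["baseline_metrics"] = baseline_evaluate_fn(
        copy.deepcopy(model_weights), client_input["test_data"])
    for name, make_strategy in personalize_fn_dict.items():
        strategy = make_strategy()
        # Every strategy starts from its own copy of the broadcast weights.
        metrics[name] = strategy(copy.deepcopy(model_weights),
                                 client_input["train_data"],
                                 client_input["test_data"])
    return metrics


def toy_baseline(weights, test_data):
    return OrderedDict(num_examples=len(test_data))


def make_count_strategy():
    def strategy(weights, train_data, test_data):
        return OrderedDict(num_train=len(train_data), num_test=len(test_data))
    return strategy


client_input = OrderedDict(train_data=[1.0, 2.0], test_data=[3.0])
result = evaluate_one_client({"bias": 0.0}, client_input, toy_baseline,
                             OrderedDict(count=make_count_strategy))
print(list(result))  # ['baseline_metrics', 'count']
```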
Raises | |
---|---|
TypeError | If arguments are of the wrong types. |
ValueError | If 'baseline_metrics' is used as a key in personalize_fn_dict. |
ValueError | If max_num_clients is not positive. |
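The error conditions above can be mirrored by an up-front argument check. check_personalization_args is a hypothetical helper written for illustration, not TFF's own validation code, and it only checks personalize_fn_dict and max_num_clients.

```python
from collections import OrderedDict
from collections.abc import Mapping


def check_personalization_args(personalize_fn_dict, max_num_clients):
    """Hypothetical check mirroring the documented TypeError/ValueError cases."""
    if not isinstance(personalize_fn_dict, Mapping):
        raise TypeError("personalize_fn_dict must be a mapping.")
    if not isinstance(max_num_clients, int):
        raise TypeError("max_num_clients must be an int.")
    if "baseline_metrics" in personalize_fn_dict:
        raise ValueError("'baseline_metrics' is reserved for the baseline results.")
    if max_num_clients <= 0:
        raise ValueError("max_num_clients must be positive.")


check_personalization_args(OrderedDict(finetune=lambda: None), 100)  # passes
for bad_args in [(OrderedDict(baseline_metrics=lambda: None), 100),
                 (OrderedDict(finetune=lambda: None), 0)]:
    try:
        check_personalization_args(*bad_args)
    except ValueError as e:
        print("ValueError:", e)
```

Reserving 'baseline_metrics' matters because the output OrderedDict uses that key for the baseline results alongside the strategy names, so a strategy with the same name would collide with it.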
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License.
Last updated 2024-09-20 UTC.