tff.learning.metrics.FinalizeThenSampleFactory  |  TensorFlow Federated

tff.learning.metrics.FinalizeThenSampleFactory

Aggregation Factory that finalizes and then samples the metrics.

Inherits From: UnweightedAggregationFactory

tff.learning.metrics.FinalizeThenSampleFactory(
    sample_size: int = 100
)

The created tff.templates.AggregationProcess finalizes each client's metrics locally, and then collects metrics from at most sample_size clients at the tff.SERVER. If more than sample_size clients participate, then sample_size clients are sampled (by the reservoir sampling algorithm); otherwise, all clients' metrics are collected. Sampling is done in a "per-client" manner, i.e., a client, once sampled, contributes all its metrics to the final result.
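The per-client sampling described above can be illustrated with a minimal, pure-Python sketch of reservoir sampling (the classic "Algorithm R"). This is not TFF's implementation; the function name `reservoir_sample`, the `seed` parameter, and the example metric values are assumptions made for illustration only.

```python
import random


def reservoir_sample(items, sample_size, seed=0):
    """Keeps a uniform random sample of at most `sample_size` items.

    Hypothetical helper: each item stands for one client's full set of
    finalized metrics, so a sampled client contributes all of its metrics.
    """
    rng = random.Random(seed)
    reservoir = []
    for index, item in enumerate(items):
        if index < sample_size:
            # The first `sample_size` items always enter the reservoir.
            reservoir.append(item)
        else:
            # Item number index + 1 replaces a random slot with
            # probability sample_size / (index + 1).
            slot = rng.randint(0, index)  # inclusive on both ends
            if slot < sample_size:
                reservoir[slot] = item
    return reservoir


# Made-up finalized metrics for ten clients.
client_metrics = [{'accuracy': 0.1 * i, 'loss': 1.0 / (i + 1)} for i in range(10)]

# Fewer clients than sample_size: every client's metrics are kept.
few = reservoir_sample(client_metrics[:3], sample_size=5)

# More clients than sample_size: exactly sample_size clients are kept.
many = reservoir_sample(client_metrics, sample_size=5)
```

Because whole per-client dictionaries are sampled, the values of different metrics stay paired by client.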

The collected metrics samples at tff.SERVER have the same structure (i.e., same keys in a dictionary) as the client's local metrics, except that each leaf node contains a list of scalar metric values, where each value comes from a sampled client, e.g.,

  sampled_metrics_at_server = {
      'metric_a': [a1, a2, ...],
      'metric_b': [b1, b2, ...],
      ...
  }

where "a1" and "b1" are from the same client (similarly for "a2" and "b2", etc.).
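As a rough illustration of this layout, the sketch below turns a list of per-client metric dictionaries into one dictionary of lists. The helper `collect_samples` and the metric values are hypothetical; TFF builds this structure internally as part of the aggregation process.

```python
import collections


def collect_samples(sampled_clients):
    """Transposes a list of per-client metrics into a dict of lists.

    Hypothetical helper: position i of every list comes from the same
    sampled client.
    """
    collected = collections.OrderedDict()
    for client in sampled_clients:
        for name, value in client.items():
            collected.setdefault(name, []).append(value)
    return collected


# Made-up finalized metrics from two sampled clients.
sampled_clients = [
    collections.OrderedDict(metric_a=0.9, metric_b=12.0),
    collections.OrderedDict(metric_a=0.7, metric_b=30.0),
]
sampled_metrics_at_server = collect_samples(sampled_clients)
```

Here `sampled_metrics_at_server['metric_a'][0]` and `sampled_metrics_at_server['metric_b'][0]` both belong to the first client.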

Both "current round samples" and "total rounds samples" are returned, and they both contain metrics from at most sample_size clients. Sampling is done across the current round's participating clients (the result is "current round samples") or across all the participating clients so far (the result is "total rounds samples").
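The difference between the two kinds of samples can be sketched in plain Python: the current-round reservoir starts empty every round, while the total-rounds reservoir and its count of clients seen carry over between rounds. The function `update_reservoir`, the client labels, and the way state is carried here are assumptions for illustration; they do not describe how TFF stores its state.

```python
import random


def update_reservoir(reservoir, num_seen, new_items, sample_size, rng):
    """Feeds `new_items` into a reservoir that has already seen `num_seen` items."""
    for item in new_items:
        if num_seen < sample_size:
            reservoir.append(item)
        else:
            # Keep the new item with probability sample_size / (num_seen + 1).
            slot = rng.randint(0, num_seen)
            if slot < sample_size:
                reservoir[slot] = item
        num_seen += 1
    return reservoir, num_seen


rng = random.Random(42)
sample_size = 3
# Four rounds with five hypothetical clients each.
rounds = [[f'round{r}_client{c}' for c in range(5)] for r in range(4)]

total_samples, total_seen = [], 0
history = []
for round_clients in rounds:
    # Current round samples: a fresh reservoir for this round only.
    current_samples, _ = update_reservoir([], 0, round_clients, sample_size, rng)
    # Total rounds samples: the reservoir persists across rounds.
    total_samples, total_seen = update_reservoir(
        total_samples, total_seen, round_clients, sample_size, rng)
    history.append((list(current_samples), list(total_samples)))
```

In every round both results hold at most `sample_size` entries, but only the total-rounds result can contain clients from earlier rounds.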

The next function of the created tff.templates.AggregationProcess takes the state and local unfinalized metrics reported from tff.CLIENTS, and returns a tff.templates.MeasuredProcessOutput object with the following properties:

Example usage
  sample_process = FinalizeThenSampleFactory(sample_size).create(
      metric_finalizers, local_unfinalized_metrics_type)
  eval_process = tff.learning.algorithms.build_fed_eval(
      model_fn=...,
      metrics_aggregation_process=sample_process,
      ...)
  state = eval_process.initialize()
  for i in range(num_rounds):
    output = eval_process.next(state, client_data_at_round_i)
    state = output.state
    current_round_samples, total_rounds_samples = output.result

The created eval_process can also be used in tff.learning.programs.EvaluationManager.

Args
sample_size An integer specifying the number of clients sampled (by reservoir sampling algorithm). Metrics from the sampled clients are collected at the server, and this sample_size applies to current round and total rounds samples (see the class documentation for details). Default value is 100.
Raises
TypeError If any argument type mismatches.
ValueError If sample_size is not positive.
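The documented error conditions for sample_size can be mirrored in a small stand-alone check. The sketch below is only an approximation of that contract: the helper names `check_sample_size` and `error_for` are hypothetical, and the exact checks and messages inside TFF may differ.

```python
def check_sample_size(sample_size):
    """Hypothetical validation that mirrors the documented Raises section."""
    if not isinstance(sample_size, int):
        raise TypeError(
            f'sample_size must be an int, got {type(sample_size).__name__}.')
    if sample_size <= 0:
        raise ValueError(f'sample_size must be positive, got {sample_size}.')
    return sample_size


def error_for(value):
    """Returns the name of the exception raised for `value`, or None."""
    try:
        check_sample_size(value)
    except (TypeError, ValueError) as err:
        return type(err).__name__
    return None
```

With this sketch, `error_for(100)` returns None, while a non-integer such as `1.5` yields `'TypeError'` and `0` or a negative value yields `'ValueError'`.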

Methods

create

create(
    metric_finalizers: Union[tff.learning.metrics.MetricFinalizersType, tff.learning.metrics.FunctionalMetricFinalizersType],
    local_unfinalized_metrics_type: tff.types.StructWithPythonType
) -> tff.templates.AggregationProcess

Creates a tff.templates.AggregationProcess for metrics aggregation.

Args
metric_finalizers Either the result of tff.learning.models.VariableModel.metric_finalizers (an OrderedDict of callables) or the tff.learning.models.FunctionalModel.finalize_metrics method (a callable that takes an OrderedDict argument). If the former, the keys must be the same as those of the OrderedDict returned by tff.learning.models.VariableModel.report_local_unfinalized_metrics. If the latter, the callable must compute over the same keyspace as the result returned by tff.learning.models.FunctionalModel.update_metrics_state.
local_unfinalized_metrics_type A tff.types.StructWithPythonType (with collections.OrderedDict as the Python container) of a client's local unfinalized metrics.
Returns
An instance of tff.templates.AggregationProcess.
Raises
TypeError If any argument type mismatches; if the metric finalizers mismatch the type of local unfinalized metrics; if the initial unfinalized metrics mismatch the type of local unfinalized metrics.
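To make the key-matching requirement on metric_finalizers concrete, the following pure-Python sketch pairs an OrderedDict of finalizer callables with a client's unfinalized metrics and applies them key by key. It does not call TFF: the `finalize` helper, the metric names, and the list-based unfinalized values are illustrative assumptions, and the real type checks in create operate on TFF types rather than plain Python values.

```python
import collections

# Hypothetical finalizers, keyed like the unfinalized metrics below.
metric_finalizers = collections.OrderedDict(
    accuracy=lambda state: state[0] / state[1],  # correct / total
    num_examples=lambda state: state[0],
)

# Made-up unfinalized metrics reported by one client.
local_unfinalized_metrics = collections.OrderedDict(
    accuracy=[45.0, 50.0],
    num_examples=[50],
)


def finalize(finalizers, unfinalized):
    """Applies each finalizer to the unfinalized value under the same key."""
    if list(finalizers.keys()) != list(unfinalized.keys()):
        # Assumption: a key mismatch is reported as a TypeError in this sketch.
        raise TypeError(
            f'Finalizer keys {list(finalizers.keys())} do not match '
            f'unfinalized metric keys {list(unfinalized.keys())}.')
    return collections.OrderedDict(
        (name, finalizers[name](unfinalized[name])) for name in finalizers)


finalized = finalize(metric_finalizers, local_unfinalized_metrics)
```

The finalized result keeps the same keys as the unfinalized metrics, which is the structure that is later sampled per client.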