tff.learning.metrics.FinalizeThenSampleFactory | TensorFlow Federated
tff.learning.metrics.FinalizeThenSampleFactory
Aggregation Factory that finalizes and then samples the metrics.
Inherits From: UnweightedAggregationFactory
tff.learning.metrics.FinalizeThenSampleFactory(
sample_size: int = 100
)
The created tff.templates.AggregationProcess finalizes each client's metrics locally and then collects metrics from at most sample_size clients at tff.SERVER. If more than sample_size clients participate, then sample_size clients are sampled (using the reservoir sampling algorithm); otherwise, all clients' metrics are collected. Sampling is done per client, i.e., once a client is sampled, all of its metrics contribute to the final result.
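The per-client sampling step can be pictured in plain Python. This is a minimal sketch, not TFF's implementation: it assumes the keyed variant of reservoir sampling, in which every client draws a uniform random key and the clients with the sample_size largest keys are kept. The helper `sample_clients` and the example metric values are hypothetical.

```python
import heapq
import random
from typing import Any, Dict, List, Tuple

# A keyed sample entry: (random key, the client's complete metrics dict).
Sample = Tuple[float, Dict[str, Any]]


def sample_clients(client_metrics: List[Dict[str, Any]],
                   sample_size: int,
                   rng: random.Random) -> List[Sample]:
  """Keeps at most `sample_size` clients, chosen uniformly at random.

  Hypothetical helper. Each client is sampled as a whole, so all of its
  metrics enter the sample together or not at all.
  """
  keyed = [(rng.random(), metrics) for metrics in client_metrics]
  # The clients holding the `sample_size` largest random keys form a uniform
  # random subset of size min(sample_size, len(client_metrics)).
  return heapq.nlargest(sample_size, keyed, key=lambda entry: entry[0])


rng = random.Random(0)
clients = [{'metric_a': float(i), 'metric_b': 10.0 * i} for i in range(5)]
sampled = sample_clients(clients, sample_size=3, rng=rng)
```

Keeping the random key next to each sampled client is what makes the keyed variant convenient here: samples drawn from different groups of clients can later be merged just by comparing keys.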
The collected metric samples at tff.SERVER have the same structure (i.e., the same keys in a dictionary) as the client's local metrics, except that each leaf node contains a list of scalar metric values, each coming from a sampled client, e.g.,
sampled_metrics_at_server = {
'metric_a': [a1, a2, ...],
'metric_b': [b1, b2, ...],
...
}
where "a1" and "b1" come from the same client (similarly for "a2" and "b2", etc.).
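The regrouping from per-client dictionaries into the per-metric lists shown above can be sketched as follows. This is an illustrative sketch only: the helper `to_server_structure` and the metric values are hypothetical, and it assumes flat dictionaries with scalar leaves, whereas TFF works on typed (possibly nested) structures.

```python
from typing import Any, Dict, List, Sequence


def to_server_structure(sampled_clients: List[Dict[str, Any]],
                        keys: Sequence[str]) -> Dict[str, List[Any]]:
  """Regroups per-client metric dicts into one dict of per-metric lists."""
  # Position i of every list comes from the same client, so 'a1' and 'b1'
  # (and likewise 'a2' and 'b2') stay aligned.
  return {key: [client[key] for client in sampled_clients] for key in keys}


sampled = [
    {'metric_a': 0.5, 'metric_b': 3},  # client 1 -> a1, b1
    {'metric_a': 0.7, 'metric_b': 4},  # client 2 -> a2, b2
]
sampled_metrics_at_server = to_server_structure(sampled, ['metric_a', 'metric_b'])
# sampled_metrics_at_server == {'metric_a': [0.5, 0.7], 'metric_b': [3, 4]}
```

Passing the keys explicitly keeps the output structure fixed even when no client was sampled (every list is then empty), loosely mirroring how the server-side structure follows the metrics type rather than the data.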
Both "current round samples" and "total rounds samples" are returned, and each contains metrics from at most sample_size clients. The current round samples are drawn from the clients participating in the current round, while the total rounds samples are drawn from all clients that have participated so far.
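One way to see how the two sample sets relate is sketched below. It assumes, as a labeled simplification that may differ from TFF's internals, that every sampled client carries its random key from keyed reservoir sampling. Under that assumption the total rounds samples can be updated by keeping the sample_size largest keys over the previous total sample plus the current round samples. The helper `update_total_samples` and the simulated rounds are hypothetical.

```python
import heapq
import random
from typing import Any, Dict, List, Tuple

# A keyed sample entry: (random key, the client's metrics dict).
Sample = Tuple[float, Dict[str, Any]]


def update_total_samples(total: List[Sample], current_round: List[Sample],
                         sample_size: int) -> List[Sample]:
  """Merges this round's keyed samples into the running all-rounds sample.

  Keeping the top `sample_size` keys of (previous total + current round)
  equals keeping the top keys over every client seen so far, because any
  client in the overall top-k must also be in the top-k of its own round.
  """
  return heapq.nlargest(sample_size, total + current_round,
                        key=lambda entry: entry[0])


rng = random.Random(1)
all_clients: List[Sample] = []
total: List[Sample] = []
for round_num in range(3):
  # Hypothetical round with 4 clients, each reporting a single metric.
  round_clients = [(rng.random(), {'metric_a': float(10 * round_num + c)})
                   for c in range(4)]
  all_clients.extend(round_clients)
  # Current round samples: the top-2 keys among this round's clients.
  current = heapq.nlargest(2, round_clients, key=lambda entry: entry[0])
  total = update_total_samples(total, current, sample_size=2)
```

The state therefore only needs the current total sample and its keys, not every client ever seen, which is consistent with the state description below mentioning the random values generated by the sampling algorithm.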
The next function of the created tff.templates.AggregationProcess takes the state and the local unfinalized metrics reported from tff.CLIENTS, and returns a tff.templates.MeasuredProcessOutput object with the following properties:
- state: a dictionary of the total rounds samples and the sampling metadata (e.g., random values generated by the reservoir sampling algorithm).
- result: a tuple of the current round samples and the total rounds samples.
- measurements: the number of non-finite (NaN or Inf) leaves in the current round's client values before sampling.
Example usage:
import tensorflow_federated as tff

sample_process = tff.learning.metrics.FinalizeThenSampleFactory(sample_size).create(
    metric_finalizers, local_unfinalized_metrics_type)
eval_process = tff.learning.algorithms.build_fed_eval(
    model_fn=..., metrics_aggregation_process=sample_process, ...)
state = eval_process.initialize()
for i in range(num_rounds):
  output = eval_process.next(state, client_data_at_round_i)
  state = output.state
  current_round_samples, total_rounds_samples = output.result
The created eval_process can also be used in tff.learning.programs.EvaluationManager.
Args | |
---|---|
sample_size | An integer specifying the number of clients to sample (using the reservoir sampling algorithm). Metrics from the sampled clients are collected at the server, and this sample_size applies to both the current round samples and the total rounds samples (see the class documentation for details). Default value is 100. |
Raises | |
---|---|
TypeError | If any argument type mismatches. |
ValueError | If sample_size is not positive. |
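The two documented checks on sample_size can be mirrored in a few lines. This is a hypothetical sketch (`validate_sample_size` is not a TFF function), and rejecting `bool` explicitly is an assumption that may be stricter than the library.

```python
def validate_sample_size(sample_size: int) -> None:
  """Raises TypeError for a non-int sample_size, ValueError if it is not positive."""
  # bool is a subclass of int in Python; this sketch rejects it explicitly.
  if isinstance(sample_size, bool) or not isinstance(sample_size, int):
    raise TypeError(
        f'sample_size must be an int, got {type(sample_size).__name__}.')
  if sample_size <= 0:
    raise ValueError(f'sample_size must be positive, got {sample_size}.')


validate_sample_size(100)  # The default value passes both checks.
```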
Methods
create
create(
metric_finalizers: Union[tff.learning.metrics.MetricFinalizersType, tff.learning.metrics.FunctionalMetricFinalizersType],
local_unfinalized_metrics_type: tff.types.StructWithPythonType
) -> tff.templates.AggregationProcess
Creates a tff.templates.AggregationProcess for metrics aggregation.
Args | |
---|---|
metric_finalizers | Either the result of tff.learning.models.VariableModel.metric_finalizers (an OrderedDict of callables) or the tff.learning.models.FunctionalModel.finalize_metrics method (a callable that takes an OrderedDict argument). If the former, the keys must be the same as those of the OrderedDict returned by tff.learning.models.VariableModel.report_local_unfinalized_metrics. If the latter, the callable must compute over the same keyspace as the result returned by tff.learning.models.FunctionalModel.update_metrics_state. |
local_unfinalized_metrics_type | A tff.types.StructWithPythonType (with collections.OrderedDict as the Python container) of a client's local unfinalized metrics. |
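The "finalize locally" step that metric_finalizers enables can be sketched in plain Python for the OrderedDict-of-callables form. Everything here is hypothetical: the metric names, the `[total correct, total count]` layout for accuracy, and the helper `finalize_locally`; in TFF the finalizers operate on tensors inside a federated computation.

```python
import collections
from typing import Any, Callable, Dict

# Hypothetical unfinalized metrics: running sums a client reports locally.
unfinalized = collections.OrderedDict(
    num_examples=8.0,
    accuracy=[6.0, 8.0],  # [total correct, total count]
)

# Hypothetical finalizers keyed exactly like the unfinalized metrics.
finalizers: Dict[str, Callable[[Any], Any]] = collections.OrderedDict(
    num_examples=lambda x: x,
    accuracy=lambda v: v[0] / v[1],
)


def finalize_locally(unfinalized_metrics, metric_finalizers):
  """Applies each finalizer to the matching unfinalized metric on the client."""
  if set(unfinalized_metrics) != set(metric_finalizers):
    raise TypeError('Finalizer keys must match the unfinalized metric keys.')
  return collections.OrderedDict(
      (name, metric_finalizers[name](value))
      for name, value in unfinalized_metrics.items())


finalized = finalize_locally(unfinalized, finalizers)
# finalized == OrderedDict(num_examples=8.0, accuracy=0.75)
```

Finalizing before sampling means each sampled value is already a per-client metric (e.g., that client's accuracy), which is why the server-side lists hold scalar metric values rather than raw sums.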
Returns |
---|
An instance of tff.templates.AggregationProcess. |
Raises | |
---|---|
TypeError | If any argument type mismatches; if the metric finalizers mismatch the type of local unfinalized metrics; if the initial unfinalized metrics mismatch the type of local unfinalized metrics. |