tff.learning.add_debug_measurements  |  TensorFlow Federated

tff.learning.add_debug_measurements

Adds measurements suitable for debugging learning processes.

tff.learning.add_debug_measurements(
    aggregation_factory: _AggregationFactory
) -> _AggregationFactory

This wraps a tff.aggregators.AggregationFactory in a new factory that produces additional measurements useful for debugging learning processes. The underlying aggregation of client values is unchanged.
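
The real wrapper operates on tff.aggregators.AggregationFactory objects and federated computations. The sketch below uses plain Python functions only to illustrate the wrapping pattern: the aggregate is passed through untouched, and measurements are attached alongside it. `with_debug_measurements`, `unweighted_mean` and the measurement keys are hypothetical names for this example, not TFF APIs.

```python
import math
from typing import Callable, Dict, List, Sequence, Tuple

Vector = List[float]


def unweighted_mean(client_values: Sequence[Vector]) -> Vector:
    """Coordinate-wise mean; stands in for the wrapped aggregation."""
    n = len(client_values)
    return [sum(coords) / n for coords in zip(*client_values)]


def with_debug_measurements(
    aggregate_fn: Callable[[Sequence[Vector]], Vector],
) -> Callable[[Sequence[Vector]], Tuple[Vector, Dict[str, float]]]:
    """Returns a new aggregation that also reports debug measurements.

    The aggregate is exactly what `aggregate_fn` returns; the wrapper only
    attaches a dictionary of measurements computed from it.
    """

    def wrapped(client_values: Sequence[Vector]) -> Tuple[Vector, Dict[str, float]]:
        result = aggregate_fn(client_values)  # underlying aggregation, unchanged
        measurements = {
            "max_entry": max(result),
            "min_entry": min(result),
            "norm": math.sqrt(sum(x * x for x in result)),
        }
        return result, measurements

    return wrapped


debug_mean = with_debug_measurements(unweighted_mean)
aggregate, measurements = debug_mean([[1.0, 2.0], [3.0, -4.0]])
print(aggregate)  # → [2.0, -1.0]
```

Because the wrapper never modifies `result`, the wrapped and unwrapped aggregations return the same value. This mirrors the guarantee stated above that only measurements are added.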

These measurements generally concern the norms of the client updates and the norm of the aggregated server update. The weighting used for averages is inherited from aggregation_factory: if it is a weighted factory, the debugging measurements use the same weights when computing averages; if it is unweighted, they use uniform weights.
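
As a rough illustration of how the weighting carries over to the client-norm statistics, the plain-Python sketch below computes a weighted mean and standard deviation of client-update L2 norms, falling back to uniform weights when none are given. `client_norm_stats` and its output keys are hypothetical. This page states that the reported standard deviation comes from the unbiased variance. For uniform weights that means Bessel's correction k/(k-1) over k clients. How TFF corrects the weighted case is not stated here, so applying the same k/(k-1) factor to weighted inputs is an assumption.

```python
import math
from typing import Dict, Optional, Sequence


def client_norm_stats(
    client_updates: Sequence[Sequence[float]],
    weights: Optional[Sequence[float]] = None,
) -> Dict[str, float]:
    """Weighted mean and standard deviation of client-update L2 norms.

    With `weights=None` every client gets weight 1.0 (uniform weighting).
    """
    norms = [math.sqrt(sum(x * x for x in u)) for u in client_updates]
    k = len(norms)
    if k < 2:
        raise ValueError("An unbiased variance needs at least two clients.")
    if weights is None:
        weights = [1.0] * k
    total = sum(weights)
    mean = sum(w * n for w, n in zip(weights, norms)) / total
    # Weighted E[n^2] - (E[n])^2 is the biased (population) variance.
    mean_sq = sum(w * n * n for w, n in zip(weights, norms)) / total
    biased_var = mean_sq - mean * mean
    # Assumed Bessel correction k / (k - 1); exact for uniform weights only.
    unbiased_var = biased_var * k / (k - 1)
    # Clamp tiny negative values caused by floating-point cancellation.
    std_dev = math.sqrt(max(unbiased_var, 0.0))
    return {"mean_norm": mean, "std_dev_norm": std_dev}


updates = [[3.0, 4.0], [0.0, 0.0]]  # client norms 5.0 and 0.0
print(client_norm_stats(updates))                       # uniform weighting
print(client_norm_stats(updates, weights=[3.0, 1.0]))   # weighted
```

With uniform weights the mean norm is 2.5. With weights [3.0, 1.0] the first client dominates and the mean norm rises to 3.75.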

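The client measurements are:

- average_client_norm: the average norm of the client updates, weighted according to aggregation_factory.
- std_dev_client_norm: the standard deviation of the norm of the client updates, weighted according to aggregation_factory.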

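The standard deviation we report is the square root of the unbiased variance. The server measurements are:

- server_update_max: the maximum entry of the aggregate client update.
- server_update_norm: the Euclidean norm of the aggregate client update.
- server_update_min: the minimum entry of the aggregate client update.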

In the above, an "entry" means any coordinate across all tensors in the structure. For example, suppose that we have client structures before aggregation:

If we use unweighted averaging, then the aggregate client update will be the structure [[-1, -3, -2], [1]]. The maximum entry is 1, the minimum entry is -3, and the Euclidean norm is sqrt(15).
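
The arithmetic in this example can be checked with a short plain-Python sketch. Nested lists stand in for a structure of tensors, and the structure is flattened so that every coordinate counts as one "entry". `flatten` and `server_update_stats` are hypothetical helpers, not TFF functions.

```python
import math
from typing import Dict, List


def flatten(structure) -> List[float]:
    """Collects every coordinate ("entry") across all nested lists."""
    if isinstance(structure, (list, tuple)):
        return [x for item in structure for x in flatten(item)]
    return [float(structure)]


def server_update_stats(aggregate) -> Dict[str, float]:
    """Max entry, min entry, and Euclidean norm over all entries."""
    entries = flatten(aggregate)
    return {
        "max": max(entries),
        "min": min(entries),
        "norm": math.sqrt(sum(x * x for x in entries)),
    }


stats = server_update_stats([[-1, -3, -2], [1]])
print(stats["max"], stats["min"], math.isclose(stats["norm"], math.sqrt(15)))
# → 1.0 -3.0 True
```

The squared entries are 1 + 9 + 4 + 1 = 15, which is where sqrt(15) comes from.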

Args
aggregation_factory: A tff.aggregators.AggregationFactory. Can be weighted or unweighted.
Returns
A tff.aggregators.AggregationFactory.