tff.simulation.baselines.BaselineTaskDatasets | TensorFlow Federated
tff.simulation.baselines.BaselineTaskDatasets
A convenience class for a task's data and preprocessing logic.
tff.simulation.baselines.BaselineTaskDatasets(
train_data: tff.simulation.datasets.ClientData,
test_data: tff.simulation.datasets.ClientData,
validation_data: Optional[tff.simulation.datasets.ClientData] = None,
train_preprocess_fn: Optional[tff.Computation] = None,
eval_preprocess_fn: Optional[tff.Computation] = None
)
Args | |
---|---|
train_data | A tff.simulation.datasets.ClientData for training. |
test_data | A tff.simulation.datasets.ClientData or a tf.data.Dataset for computing test metrics. |
validation_data | An optional tff.simulation.datasets.ClientData or a tf.data.Dataset for computing validation metrics. |
train_preprocess_fn | An optional callable accepting and returning a tf.data.Dataset, used to preprocess datasets for training. If None, the identity map is used for all train preprocessing. |
eval_preprocess_fn | An optional callable accepting and returning a tf.data.Dataset, used for evaluation (e.g., validation and testing) preprocessing. If None, evaluation preprocessing uses the identity map. |
Raises | |
---|---|
ValueError | If train_data and test_data have different element types after preprocessing with train_preprocess_fn and eval_preprocess_fn, or if validation_data is not None and has a different element type than the test data. |
Attributes | |
---|---|
train_data | A tff.simulation.datasets.ClientData for training. |
test_data | The test data for the baseline task. Can be a tff.simulation.datasets.ClientData or a tf.data.Dataset. |
validation_data | The validation data for the baseline task. Can be a tff.simulation.datasets.ClientData, a tf.data.Dataset, or None if the task does not have a validation dataset. |
train_preprocess_fn | A callable accepting and returning tf.data.Dataset instances, used for preprocessing train datasets. Set to None if no train preprocessing occurs for the task. |
eval_preprocess_fn | A callable accepting and returning tf.data.Dataset instances, used for preprocessing evaluation datasets. Set to None if no eval preprocessing occurs for the task. |
element_type_structure | A nested structure of tf.TensorSpec objects defining the type of the elements contained in datasets associated with this task. |
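The constructor's element-type check can be illustrated with a toy, pure-Python model. This is a hedged sketch, not the TFF implementation: it assumes a dataset is a Python list of dict examples (standing in for tf.data.Dataset), federated data is a dict from hypothetical client IDs to such lists (standing in for ClientData), and an element type is approximated by each feature's Python type name in the first example. The helpers `element_type` and `check_element_types` are hypothetical, and comparing the validation data after eval preprocessing is an assumption of this sketch.

```python
from typing import Callable, Dict, List, Optional, Union

# Toy stand-ins (assumptions, not TFF types).
Dataset = List[Dict[str, object]]   # plays the role of tf.data.Dataset
FederatedData = Dict[str, Dataset]  # plays the role of ClientData
Data = Union[Dataset, FederatedData]
PreprocessFn = Callable[[Dataset], Dataset]


def element_type(data: Data, preprocess_fn: Optional[PreprocessFn]) -> Dict[str, str]:
    """Approximates an element spec as {feature name: value type name}."""
    # A preprocessing function of None behaves like the identity map.
    fn = preprocess_fn if preprocess_fn is not None else (lambda ds: ds)
    if isinstance(data, dict):
        # Federated data: inspect one client's dataset as a representative.
        data = next(iter(data.values()))
    first_example = fn(data)[0]
    return {key: type(value).__name__ for key, value in first_example.items()}


def check_element_types(
    train_data: FederatedData,
    test_data: Data,
    validation_data: Optional[Data] = None,
    train_preprocess_fn: Optional[PreprocessFn] = None,
    eval_preprocess_fn: Optional[PreprocessFn] = None,
) -> Dict[str, str]:
    """Raises ValueError when preprocessed element types disagree."""
    train_type = element_type(train_data, train_preprocess_fn)
    test_type = element_type(test_data, eval_preprocess_fn)
    if train_type != test_type:
        raise ValueError(f"Train/test element types differ: {train_type} vs {test_type}")
    if validation_data is not None:
        # Assumption: validation data is compared after eval preprocessing.
        validation_type = element_type(validation_data, eval_preprocess_fn)
        if validation_type != test_type:
            raise ValueError(
                f"Validation/test element types differ: {validation_type} vs {test_type}")
    return test_type
```

For example, federated train examples shaped like `{"x": 1.0, "y": 0}` and centralized test data `[{"x": 2.0}]` trigger the ValueError unless a train preprocessing function drops the `"y"` feature.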
Methods
get_centralized_test_data
get_centralized_test_data() -> tf.data.Dataset
Returns a tf.data.Dataset of test data for the task.
If the test data is centralized (a tf.data.Dataset), this method returns it after applying eval preprocessing. If the test data is federated (a tff.simulation.datasets.ClientData), this method first combines all client datasets into a single dataset, then applies eval preprocessing.
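The behavior described above can be sketched with toy stand-ins: a dict of per-client example lists for federated data and a plain list for centralized data (both assumptions, not TFF types). This is a hedged, pure-Python analogue rather than the TFF code; concatenating clients in sorted-ID order is a choice made here for determinism and is not a claim about TFF's ordering.

```python
from typing import Callable, Dict, List, Optional, Union

Dataset = List[Dict[str, object]]   # stand-in for tf.data.Dataset (assumption)
FederatedData = Dict[str, Dataset]  # stand-in for ClientData (assumption)


def get_centralized_test_data(
    test_data: Union[Dataset, FederatedData],
    eval_preprocess_fn: Optional[Callable[[Dataset], Dataset]] = None,
) -> Dataset:
    """Toy version: merge clients if federated, then apply eval preprocessing."""
    if isinstance(test_data, dict):
        # Federated test data: concatenate every client's examples into one dataset.
        centralized: Dataset = []
        for client_id in sorted(test_data):
            centralized.extend(test_data[client_id])
    else:
        # Centralized test data is used as-is (copied so the input is untouched).
        centralized = list(test_data)
    # A preprocessing function of None behaves like the identity map.
    if eval_preprocess_fn is None:
        return centralized
    return eval_preprocess_fn(centralized)
```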
sample_train_clients
sample_train_clients(
num_clients: int, replace: bool = False, random_seed: Optional[int] = None
) -> list[tf.data.Dataset]
Samples training clients uniformly at random.
Args | |
---|---|
num_clients | A positive integer representing the number of clients to be sampled. |
replace | Whether to sample with replacement. If set to False, then num_clients cannot exceed the number of training clients in the associated train data. |
random_seed | An optional integer used to set a random seed for sampling. If no random seed is passed or the random seed is set to None, this will attempt to set the random seed according to the current system time (see numpy.random.RandomState for details). |
Returns |
---|
A list of tf.data.Dataset instances representing the client datasets. |
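Uniform client sampling, as described above, can be sketched in pure Python. This is a hedged analogue, not the TFF code: the toy train data is a dict from hypothetical client IDs to example lists, and it uses the standard library's `random.Random` in place of `numpy.random.RandomState`, so seeded results will not match TFF's. The explicit ValueError for oversampling without replacement is this sketch's own check.

```python
import random
from typing import Dict, List, Optional

Dataset = List[Dict[str, object]]  # stand-in for tf.data.Dataset (assumption)


def sample_train_clients(
    train_data: Dict[str, Dataset],
    num_clients: int,
    replace: bool = False,
    random_seed: Optional[int] = None,
) -> List[Dataset]:
    """Uniformly samples client datasets (toy analogue of the documented method)."""
    if num_clients <= 0:
        raise ValueError("num_clients must be a positive integer.")
    client_ids = sorted(train_data)
    if not replace and num_clients > len(client_ids):
        raise ValueError(
            "Without replacement, num_clients cannot exceed the number of training clients.")
    # random.Random(None) seeds from OS randomness (or the system time as a fallback).
    rng = random.Random(random_seed)
    if replace:
        chosen = rng.choices(client_ids, k=num_clients)  # client IDs may repeat
    else:
        chosen = rng.sample(client_ids, num_clients)     # client IDs are distinct
    return [train_data[client_id] for client_id in chosen]
```

Passing the same `random_seed` twice yields the same sample, which is what makes seeded simulation rounds reproducible.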
summary
summary(
print_fn: Callable[[str], Any] = print
)
Prints a summary of the train, test, and validation data.
The summary is printed as a table showing, for each of the train, test, and validation data, its type (i.e., federated or centralized) and, if it is federated, its number of clients. For example, if the train data has 10 clients and both the test and validation data are centralized, this prints the following table:
Split |Dataset Type |Number of Clients |
=============================================
Train |Federated |10 |
Test |Centralized |N/A |
Validation |Centralized |N/A |
_____________________________________________
In addition, this will print two lines after the table indicating whether train and eval preprocessing functions were passed in. In the example above, if we passed in a train preprocessing function but no eval preprocessing function, it would also print the lines:
Train Preprocess Function: True
Eval Preprocess Function: False
To capture the summary, pass a custom print function. For example, setting print_fn = summary_list.append appends each of the lines above to summary_list.
Args | |
---|---|
print_fn | An optional callable accepting string inputs. Used to print each row of the summary. Defaults to print if not specified. |
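The print_fn capture pattern can be demonstrated with a hypothetical `summarize` helper that emits one summary line per print_fn call. It is a toy approximation: the column widths are illustrative rather than TFF's exact formatting, the dict and list stand-ins for federated and centralized data are assumptions, and omitting the validation row when there is no validation data is this sketch's own choice.

```python
from typing import Any, Callable, Dict, List, Optional, Union

Dataset = List[Dict[str, object]]   # stand-in for tf.data.Dataset (assumption)
FederatedData = Dict[str, Dataset]  # stand-in for ClientData (assumption)
Data = Union[Dataset, FederatedData]


def describe(split: str, data: Data) -> str:
    """Formats one table row: split name, dataset type, client count."""
    if isinstance(data, dict):
        kind, num_clients = "Federated", str(len(data))
    else:
        kind, num_clients = "Centralized", "N/A"
    return f"{split:<12}|{kind:<13}|{num_clients:<18}|"


def summarize(
    train_data: FederatedData,
    test_data: Data,
    validation_data: Optional[Data] = None,
    train_preprocess_fn: Optional[Callable[[Dataset], Dataset]] = None,
    eval_preprocess_fn: Optional[Callable[[Dataset], Dataset]] = None,
    print_fn: Callable[[str], Any] = print,
) -> None:
    """Emits each summary line through print_fn (toy layout)."""
    header = f"{'Split':<12}|{'Dataset Type':<13}|{'Number of Clients':<18}|"
    print_fn(header)
    print_fn("=" * len(header))
    print_fn(describe("Train", train_data))
    print_fn(describe("Test", test_data))
    if validation_data is not None:
        print_fn(describe("Validation", validation_data))
    print_fn("_" * len(header))
    print_fn(f"Train Preprocess Function: {train_preprocess_fn is not None}")
    print_fn(f"Eval Preprocess Function: {eval_preprocess_fn is not None}")


# Capture the lines in a list instead of printing them.
summary_list: List[str] = []
summarize(
    {f"client_{i}": [] for i in range(10)},  # 10 federated train clients
    test_data=[],
    validation_data=[],
    train_preprocess_fn=lambda ds: ds,
    print_fn=summary_list.append,
)
```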