tfds.split_for_jax_process | TensorFlow Datasets
Returns the subsplit of the data for the process.
tfds.split_for_jax_process(
split: str,
*,
process_index: tfds.typing.Dim = None,
process_count: tfds.typing.Dim = None,
drop_remainder: bool = False
) -> tfds.typing.SplitArg
In a distributed setting, every process/host should get a non-overlapping, equally sized slice of the entire data. This function takes a split as input and extracts the slice for the current process index.
Usage:
tfds.load(..., split=tfds.split_for_jax_process('train'))
This function is an alias for:
tfds.even_splits(split, n=jax.process_count())[jax.process_index()]
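By default (drop_remainder=False), if examples can't be evenly distributed across processes, the first processes each receive one extra example. To give every process exactly the same number of examples, drop the extra examples with drop_remainder=True.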
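The effect of drop_remainder on per-process example counts can be sketched in plain Python. This is only an illustration of the arithmetic described on this page, not the library's implementation; subsplit_sizes is a hypothetical helper name and is not part of tfds.

```python
from __future__ import annotations


def subsplit_sizes(
    num_examples: int, process_count: int, drop_remainder: bool = False
) -> list[int]:
    """Hypothetical helper: how many examples each process would receive."""
    base, extra = divmod(num_examples, process_count)
    if drop_remainder:
        # Every process gets the same count; the `extra` examples are dropped.
        return [base] * process_count
    # The first `extra` processes each receive one additional example.
    return [base + 1 if i < extra else base for i in range(process_count)]


print(subsplit_sizes(11, 3))                       # [4, 4, 3]
print(subsplit_sizes(11, 3, drop_remainder=True))  # [3, 3, 3]
```

Equal counts per process usually matter when every process must run the same number of steps, which is the case drop_remainder=True is meant for.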
Args | |
---|---|
split | Split to distribute across hosts (e.g. train[75%:], train[:800]+validation[:100]). |
process_index | Process index in [0, process_count). Defaults to jax.process_index(). |
process_count | Number of processes. Defaults to jax.process_count(). |
drop_remainder | Drop examples if the number of examples in the dataset is not evenly divisible by the number of processes n. If False, examples are distributed evenly across subsplits, starting with the first. For example, if there are 11 examples with n=3, the splits will contain [4, 4, 3] examples respectively. |
Returns | |
---|---|
subsplit | The sub-split of the given split for the current process_index. |
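To make the returned sub-split more concrete, the sketch below models each process's share as a half-open index range over the examples. It rests on assumptions not stated on this page: that slices are contiguous and that drop_remainder=True drops trailing examples. The real function returns a split specification rather than index pairs, and process_range is a hypothetical name.

```python
def process_range(num_examples, process_index, process_count, drop_remainder=False):
    """Hypothetical model: (start, stop) example indices for one process."""
    base, extra = divmod(num_examples, process_count)
    if drop_remainder:
        # Assumption: equal-sized slices, trailing examples are dropped.
        start = process_index * base
        return start, start + base
    # Processes before this one contribute `base` examples each, plus one
    # extra example for each of the first `extra` processes.
    start = process_index * base + min(process_index, extra)
    size = base + (1 if process_index < extra else 0)
    return start, start + size


ranges = [process_range(11, i, 3) for i in range(3)]
print(ranges)  # [(0, 4), (4, 8), (8, 11)]
```

Each range starts where the previous one ends, so in this model the slices do not overlap and, without drop_remainder, together cover all examples.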