tf.distribute.DistributedValues | TensorFlow v2.16.1
Base class for representing distributed values.
A subclass instance of tf.distribute.DistributedValues is created when creating variables within a distribution strategy, iterating a tf.distribute.DistributedDataset, or calling tf.distribute.Strategy.run. This base class should never be instantiated directly. A tf.distribute.DistributedValues contains one value per replica. Depending on the subclass, the values are either synced on update, synced on demand, or never synced.
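The "synced on demand" case can be illustrated with a variable created with synchronization=ON_READ. Below is a minimal sketch under a few assumptions:

- TensorFlow 2.x is installed and no GPUs are available, so the two GPU replicas used elsewhere on this page are replaced by two logical CPU devices.
- The names counter and increment are hypothetical.

Each replica updates only its own local copy. The copies are combined (summed here) only when the variable is read in a cross-replica context.

```python
import tensorflow as tf

# Assumption: no GPUs, so split the physical CPU into two logical CPU devices
# to give MirroredStrategy two replicas.
cpus = tf.config.list_physical_devices("CPU")
try:
    tf.config.set_logical_device_configuration(
        cpus[0],
        [tf.config.LogicalDeviceConfiguration(),
         tf.config.LogicalDeviceConfiguration()])
except RuntimeError:
    pass  # Logical devices were already configured earlier in this process.

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])

with strategy.scope():
    # ON_READ ("synced on demand"): one independent copy per replica,
    # aggregated with SUM only on a cross-replica read.
    # ON_READ variables must not be trainable.
    counter = tf.Variable(
        0.0,
        trainable=False,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)


@tf.function
def increment():
    ctx = tf.distribute.get_replica_context()
    # Replica 0 adds 1.0 and replica 1 adds 2.0; no communication happens here.
    counter.assign_add(tf.cast(ctx.replica_id_in_sync_group, tf.float32) + 1.0)


strategy.run(increment)

# Each local copy holds only its own replica's updates.
local_copies = [float(v.numpy())
                for v in strategy.experimental_local_results(counter)]
# Reading in cross-replica context applies the SUM aggregation.
synced_total = float(counter.read_value().numpy())
print(local_copies, synced_total)
```

By contrast, a variable created with the default synchronization inside strategy.scope() is synced on update (mirrored).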
Two representative types of tf.distribute.DistributedValues are tf.types.experimental.PerReplica and tf.types.experimental.Mirrored values.
PerReplica values exist on the worker devices, with a different value for each replica. They are produced by iterating through a distributed dataset returned by tf.distribute.Strategy.experimental_distribute_dataset (Example 1, below) or tf.distribute.Strategy.distribute_datasets_from_function. They are also the typical result returned by tf.distribute.Strategy.run (Example 2).
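The distribute_datasets_from_function path is not demonstrated in the examples. Here is a minimal sketch of it, with these assumptions:

- The two replicas are logical CPU devices.
- GLOBAL_BATCH_SIZE and dataset_fn are hypothetical names.

The dataset function batches by the per-replica batch size, and each iteration step yields a PerReplica value.

```python
import tensorflow as tf

# Assumption: no GPUs, so split the physical CPU into two logical CPU devices.
cpus = tf.config.list_physical_devices("CPU")
try:
    tf.config.set_logical_device_configuration(
        cpus[0],
        [tf.config.LogicalDeviceConfiguration(),
         tf.config.LogicalDeviceConfiguration()])
except RuntimeError:
    pass  # Logical devices were already configured earlier in this process.

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
GLOBAL_BATCH_SIZE = 2  # Hypothetical; matches the .batch(2) used on this page.


def dataset_fn(input_context):
    # Per-replica batch size: 2 global elements / 2 replicas = 1.
    batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
    ds = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.])
    # Shard across input pipelines; a single worker has only one pipeline.
    ds = ds.shard(input_context.num_input_pipelines,
                  input_context.input_pipeline_id)
    return ds.batch(batch_size)


dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)

shapes_per_step = []
step_totals = []
for distributed_values in dist_dataset:
    # Each step is a PerReplica value holding one tensor per replica.
    parts = strategy.experimental_local_results(distributed_values)
    shapes_per_step.append([p.shape.as_list() for p in parts])
    step_totals.append(float(strategy.reduce(
        tf.distribute.ReduceOp.SUM, distributed_values, axis=0)))

print(shapes_per_step, step_totals)
```

Unlike experimental_distribute_dataset, this function does not rebatch the dataset. Batching and sharding are left to dataset_fn.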
Mirrored values are like PerReplica values, except that the values on all replicas are known to be the same. Mirrored values are kept synchronized by the distribution strategy in use, while PerReplica values are left unsynchronized. Mirrored values typically represent model weights. A Mirrored value can safely be read in a cross-replica context by using the value on any replica, whereas PerReplica values should not be read or manipulated in a cross-replica context.
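A variable created under strategy.scope() with default settings is a typical mirrored value. The minimal sketch below makes a few assumptions:

- Two logical CPU devices stand in for the GPUs.
- weight and read_weight are hypothetical names.

It shows that a cross-replica update reaches every copy, and that every replica then reads the same value.

```python
import tensorflow as tf

# Assumption: no GPUs, so split the physical CPU into two logical CPU devices.
cpus = tf.config.list_physical_devices("CPU")
try:
    tf.config.set_logical_device_configuration(
        cpus[0],
        [tf.config.LogicalDeviceConfiguration(),
         tf.config.LogicalDeviceConfiguration()])
except RuntimeError:
    pass  # Logical devices were already configured earlier in this process.

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])

with strategy.scope():
    # Created under the strategy scope: one copy per replica, kept in sync.
    weight = tf.Variable(1.0)

# An update made in cross-replica context is applied to every copy.
weight.assign(5.0)

copies = [float(v.numpy()) for v in strategy.experimental_local_results(weight)]
# Safe to read in cross-replica context because all copies agree.
cross_replica_read = float(weight.read_value().numpy())


@tf.function
def read_weight():
    # In replica context, each replica reads its own (identical) copy.
    return weight.read_value()


replica_reads = [float(t.numpy()) for t in
                 strategy.experimental_local_results(strategy.run(read_weight))]
print(copies, cross_replica_read, replica_reads)
```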
tf.distribute.DistributedValues can be reduced via strategy.reduce to obtain a single value across replicas (Example 4), used as input into tf.distribute.Strategy.run (Example 3), or collected to inspect the per-replica values using tf.distribute.Strategy.experimental_local_results (Example 5).
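The axis argument of strategy.reduce controls whether the batch dimension is reduced as well as the replicas. This is a minimal sketch, assuming two logical CPU devices in place of the GPUs; the variable names are hypothetical. It uses the same first batch as Example 1:

- With axis=None, values are combined across replicas element-wise.
- With axis=0, values are also reduced along the batch dimension.
- With axis=0, MEAN divides by the size of the whole global batch.

```python
import tensorflow as tf

# Assumption: no GPUs, so split the physical CPU into two logical CPU devices.
cpus = tf.config.list_physical_devices("CPU")
try:
    tf.config.set_logical_device_configuration(
        cpus[0],
        [tf.config.LogicalDeviceConfiguration(),
         tf.config.LogicalDeviceConfiguration()])
except RuntimeError:
    pass  # Logical devices were already configured earlier in this process.

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
distributed_values = next(iter(strategy.experimental_distribute_dataset(dataset)))
# Replica 0 holds [5.] and replica 1 holds [6.].

# axis=None: combine across replicas only (element-wise); shape stays (1,).
across_replicas = strategy.reduce(
    tf.distribute.ReduceOp.SUM, distributed_values, axis=None)
# axis=0: also reduce along the batch dimension, giving a scalar.
global_sum = strategy.reduce(
    tf.distribute.ReduceOp.SUM, distributed_values, axis=0)
# MEAN with axis=0 averages over the whole global batch: (5 + 6) / 2.
global_mean = strategy.reduce(
    tf.distribute.ReduceOp.MEAN, distributed_values, axis=0)
print(across_replicas.numpy(), global_sum.numpy(), global_mean.numpy())
```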
Example usages:
1. Created from a tf.distribute.DistributedDataset:
>>> import tensorflow as tf
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
>>> distributed_values
PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
  1: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>
}
2. Returned by run:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> @tf.function
... def run():
...     ctx = tf.distribute.get_replica_context()
...     return ctx.replica_id_in_sync_group
>>> distributed_values = strategy.run(run)
>>> distributed_values
PerReplica:{
  0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
  1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
}
3. As input into run:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
>>> @tf.function
... def run(input):
...     return input + 1.0
>>> updated_value = strategy.run(run, args=(distributed_values,))
>>> updated_value
PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>,
  1: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([7.], dtype=float32)>
}
4. As input into reduce:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
>>> reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
...                                 distributed_values,
...                                 axis=0)
>>> reduced_value
<tf.Tensor: shape=(), dtype=float32, numpy=11.0>
5. How to inspect per-replica values locally:
>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
>>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
>>> distributed_values = next(dataset_iterator)
>>> per_replica_values = strategy.experimental_local_results(
...     distributed_values)
>>> per_replica_values
(<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
 <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)
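The examples above each use a single batch. The minimal sketch below combines run, reduce and experimental_local_results over a whole distributed dataset, the way a step loop might. It assumes two logical CPU devices in place of the GPUs; add_one and the other names are hypothetical. Each global batch becomes one PerReplica value, which is passed to run. The result is then inspected locally and also reduced to a scalar.

```python
import tensorflow as tf

# Assumption: no GPUs, so split the physical CPU into two logical CPU devices.
cpus = tf.config.list_physical_devices("CPU")
try:
    tf.config.set_logical_device_configuration(
        cpus[0],
        [tf.config.LogicalDeviceConfiguration(),
         tf.config.LogicalDeviceConfiguration()])
except RuntimeError:
    pass  # Logical devices were already configured earlier in this process.

strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dist_dataset = strategy.experimental_distribute_dataset(dataset)


@tf.function
def add_one(x):
    return x + 1.0


local_results_per_step = []
running_total = 0.0
for distributed_values in dist_dataset:  # one PerReplica value per global batch
    updated = strategy.run(add_one, args=(distributed_values,))
    # Inspect each replica's result for this step.
    local_results_per_step.append(
        [t.numpy().tolist() for t in strategy.experimental_local_results(updated)])
    # Collapse the step's per-replica results into one scalar.
    running_total += float(
        strategy.reduce(tf.distribute.ReduceOp.SUM, updated, axis=0))

print(local_results_per_step, running_total)
```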