tf.distribute.experimental.ValueContext

A class wrapping information needed by a distribute function.

tf.distribute.experimental.ValueContext(
    replica_id_in_sync_group=0, num_replicas_in_sync=1
)

This is a context class that is passed to the value_fn in strategy.experimental_distribute_values_from_function and contains information about the compute replicas. The num_replicas_in_sync and replica_id_in_sync_group attributes can be used to customize the value on each replica.

Example usage:

  1. Directly constructed.

     def value_fn(context):
       return context.replica_id_in_sync_group / context.num_replicas_in_sync

     context = tf.distribute.experimental.ValueContext(
         replica_id_in_sync_group=2, num_replicas_in_sync=4)
     per_replica_value = value_fn(context)
     per_replica_value  # 0.5
  2. Passed in by experimental_distribute_values_from_function.

     strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

     def value_fn(value_context):
       return value_context.num_replicas_in_sync

     distributed_values = (
         strategy.experimental_distribute_values_from_function(value_fn))
     local_result = strategy.experimental_local_results(distributed_values)
     local_result  # (2, 2)
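
A common next step, not shown on this page but sketched here for illustration, is to pass the distributed values into strategy.run, where each replica receives the value that value_fn built for it. The function name replica_fn below is a hypothetical name chosen for this sketch:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

    def value_fn(value_context):
      # Each replica gets its own replica ID as a tensor.
      return tf.constant(value_context.replica_id_in_sync_group)

    distributed_values = (
        strategy.experimental_distribute_values_from_function(value_fn))

    @tf.function
    def replica_fn(per_replica_value):
      # per_replica_value is the tensor that value_fn built for this replica.
      return per_replica_value * 2

    result = strategy.run(replica_fn, args=(distributed_values,))
    strategy.experimental_local_results(result)  # expected: (0, 2) on two replicas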

Args
replica_id_in_sync_group: The current replica ID; should be an int in [0, num_replicas_in_sync).
num_replicas_in_sync: The number of replicas that are in sync.
Attributes
num_replicas_in_sync: Returns the number of compute replicas in sync.
replica_id_in_sync_group: Returns the replica ID.
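
Both attributes can be combined to hand each replica its own slice of a host-side tensor. The sketch below is an assumption-based illustration, not part of the documented API surface, and assumes a two-replica MirroredStrategy:

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])

    # One row per replica; the leading dimension matches num_replicas_in_sync.
    full_batch = tf.reshape(tf.range(8.), [2, 4])

    def value_fn(ctx):
      # The leading dimension is expected to equal ctx.num_replicas_in_sync.
      assert full_batch.shape[0] == ctx.num_replicas_in_sync
      # Give each replica the row that corresponds to its replica ID.
      return full_batch[ctx.replica_id_in_sync_group]

    per_replica_batch = (
        strategy.experimental_distribute_values_from_function(value_fn))
    strategy.experimental_local_results(per_replica_batch)
    # expected: one row of shape (4,) per replica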
