tf.raw_ops.CrossReplicaSum | TensorFlow v2.16.1
An Op to sum inputs across replicated TPU instances.
Compat aliases for migration
See the Migration guide for more details.
tf.compat.v1.raw_ops.CrossReplicaSum
tf.raw_ops.CrossReplicaSum(
input, group_assignment, name=None
)
Each instance supplies its own input.
For example, suppose there are 8 TPU instances: [A, B, C, D, E, F, G, H]. Passing group_assignment=[[0,2,4,6],[1,3,5,7]] sets A, C, E, G as group 0, and B, D, F, H as group 1. Thus we get the outputs: [A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H].
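
The grouping in the example above can be modelled in plain Python. This is only a sketch of the semantics, not the TPU op itself: the function name simulate_cross_replica_sum and the integer stand-ins for A through H are assumptions made for illustration, and no TensorFlow call is involved.

```python
def simulate_cross_replica_sum(replica_inputs, group_assignment):
    """Hypothetical pure-Python model of the grouping semantics."""
    outputs = [None] * len(replica_inputs)
    for group in group_assignment:
        # Sum the inputs of every replica listed in this group.
        group_total = sum(replica_inputs[replica_id] for replica_id in group)
        # Each member of the group receives the same group-wide sum.
        for replica_id in group:
            outputs[replica_id] = group_total
    return outputs


# Integer stand-ins for the inputs A..H of replicas 0..7 (assumed values).
replica_inputs = [1, 2, 3, 4, 5, 6, 7, 8]
group_assignment = [[0, 2, 4, 6], [1, 3, 5, 7]]

result = simulate_cross_replica_sum(replica_inputs, group_assignment)
print(result)  # [16, 20, 16, 20, 16, 20, 16, 20]
```

Replicas 0, 2, 4, 6 all receive 1+3+5+7 and replicas 1, 3, 5, 7 all receive 2+4+6+8, matching the alternating output pattern described above.
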
Args | |
---|---|
input | A Tensor. Must be one of the following types: half, bfloat16, float32, float64, int32, uint32. The local input to the sum. |
group_assignment | A Tensor of type int32 with shape [num_groups, num_replicas_per_group]. group_assignment[i] lists the replica ids in the ith subgroup. |
name | A name for the operation (optional). |
Returns |
---|
A Tensor. Has the same type as input. |
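
Because group_assignment is described as having shape [num_groups, num_replicas_per_group], every group must list the same number of replica ids. The sketch below checks a nested Python list for that rectangular shape before it would be converted to an int32 tensor. The helper name group_assignment_shape is hypothetical, and the check covers only the shape requirement stated in the Args table, not any further constraint the op may enforce.

```python
def group_assignment_shape(group_assignment):
    """Hypothetical helper: return (num_groups, num_replicas_per_group).

    Raises ValueError if the nested list is empty or ragged, since such a
    list cannot form a 2-D tensor of the documented shape.
    """
    group_sizes = {len(group) for group in group_assignment}
    if len(group_sizes) != 1:
        raise ValueError("every group must list the same number of replica ids")
    return len(group_assignment), group_sizes.pop()


shape = group_assignment_shape([[0, 2, 4, 6], [1, 3, 5, 7]])
print(shape)  # (2, 4)
```
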