On each replica, the input is split into split_count blocks along
split_dimension and sent to the other replicas specified by group_assignment.
After receiving split_count - 1 blocks from the other replicas, the op
concatenates the blocks along concat_dimension to form the output.
For example, suppose there are 2 TPU replicas:
replica 0 receives input: [[A, B]]
replica 1 receives input: [[C, D]]
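The exchange described above can be sketched in plain Python, without TensorFlow or a TPU. This is only an illustrative simulation, not the op's implementation. The helper names (`split_2d`, `concat_2d`, `simulate_all_to_all`) are hypothetical, the sketch handles 2-D nested lists only, and it assumes that block j of each replica goes to the j-th replica listed in its subgroup. The argument values used here (group_assignment `[[0, 1]]`, concat_dimension 0, split_dimension 1, split_count 2) are assumed for illustration.

```python
def split_2d(matrix, count, axis):
    """Split a 2-D nested list into `count` equal blocks along `axis` (0 or 1)."""
    size = len(matrix) if axis == 0 else len(matrix[0])
    if size % count != 0:
        raise ValueError("dimension size must be divisible by the block count")
    step = size // count
    if axis == 0:
        return [matrix[i * step:(i + 1) * step] for i in range(count)]
    return [[row[i * step:(i + 1) * step] for row in matrix] for i in range(count)]


def concat_2d(blocks, axis):
    """Concatenate 2-D nested lists along `axis` (0 or 1)."""
    if axis == 0:
        return [row for block in blocks for row in block]
    return [sum((block[r] for block in blocks), []) for r in range(len(blocks[0]))]


def simulate_all_to_all(inputs, group_assignment, concat_dimension,
                        split_dimension, split_count):
    """Simulate the exchange; `inputs` maps replica id -> 2-D nested list."""
    outputs = {}
    for group in group_assignment:
        if len(group) != split_count:
            raise ValueError("split_count must equal the subgroup size")
        # Every replica in the subgroup splits its own input.
        blocks = {r: split_2d(inputs[r], split_count, split_dimension) for r in group}
        for position, replica in enumerate(group):
            # Assumption: the replica at `position` collects block `position`
            # from every member of its subgroup, in subgroup order.
            received = [blocks[sender][position] for sender in group]
            outputs[replica] = concat_2d(received, concat_dimension)
    return outputs


inputs = {0: [["A", "B"]], 1: [["C", "D"]]}
outputs = simulate_all_to_all(inputs, group_assignment=[[0, 1]],
                              concat_dimension=0, split_dimension=1,
                              split_count=2)
print(outputs[0])  # [['A'], ['C']]
print(outputs[1])  # [['B'], ['D']]
```

Under these assumed settings, each replica keeps one block of its own input and receives one block from its peer, so the single-row input becomes a two-row output.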
input
A Tensor. Must be one of the following types: float32, float64, int32, uint8, int16, int8, complex64, int64, qint8, quint8, qint32, bfloat16, qint16, quint16, uint16, complex128, half, uint32, uint64, bool.
The local input to the all-to-all exchange.
group_assignment
A Tensor of type int32. An int32 tensor with shape
[num_groups, num_replicas_per_group]. group_assignment[i] represents the
replica ids in the ith subgroup.
concat_dimension
An int. The dimension along which to concatenate.
split_dimension
An int. The dimension along which to split.
split_count
An int.
The number of splits. This number must equal the subgroup size
(group_assignment.get_shape()[1]).
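As a rough illustration of how these arguments interact, the sketch below computes the per-replica output shape implied by the description: the split dimension shrinks by a factor of split_count and the concat dimension grows by the same factor. The function name `all_to_all_output_shape` is hypothetical, and the requirement that the split dimension be evenly divisible by split_count is an assumption of this sketch, not something stated above.

```python
def all_to_all_output_shape(input_shape, concat_dimension, split_dimension,
                            split_count):
    """Return the per-replica output shape implied by the description."""
    shape = list(input_shape)
    # Assumption: the split dimension divides evenly into split_count blocks.
    if shape[split_dimension] % split_count != 0:
        raise ValueError("split dimension must be divisible by split_count")
    shape[split_dimension] //= split_count   # each block is 1/split_count wide
    shape[concat_dimension] *= split_count   # split_count blocks are joined
    return shape


group_assignment = [[0, 1]]              # one subgroup holding two replicas
split_count = len(group_assignment[0])   # must match the subgroup size
print(all_to_all_output_shape([1, 2], concat_dimension=0,
                              split_dimension=1, split_count=split_count))
# [2, 1]
```

When split_dimension and concat_dimension are the same, the two adjustments cancel and the output shape equals the input shape.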
Last updated 2024-01-23 UTC.