Gathers values from the input tensor based on per-row indices.
tfr.utils.gather_per_row(
inputs, indices
)
Example Usage:
scores = [[1., 3., 2.], [1., 2., 3.]]
indices = [[1, 2], [2, 1]]
tfr.utils.gather_per_row(scores, indices)
Returns [[3., 2.], [3., 2.]]
Returns |
---|
A tensor of values gathered from inputs, of shape [batch_size, size] or [batch_size, size, feature_dims], depending on whether the input was 2D or 3D. |
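The 2D and 3D output shapes described above can be illustrated with a small pure-Python sketch of the gathering semantics. This is not the library's implementation: `gather_per_row_reference` is a hypothetical helper written only for this illustration, it works on nested lists rather than tensors, and the `features` data is made up.

```python
def gather_per_row_reference(inputs, indices):
    """Hypothetical reference for per-row gathering on nested lists.

    For each row, pick the entries at that row's own indices. With 2D
    inputs each picked entry is a scalar; with 3D inputs each picked
    entry is a feature vector, so the trailing dimension is preserved.
    """
    return [
        [row[i] for i in row_indices]
        for row, row_indices in zip(inputs, indices)
    ]


# 2D case: inputs of shape [batch_size, list_size].
scores = [[1., 3., 2.], [1., 2., 3.]]
indices = [[1, 2], [2, 1]]
gathered_2d = gather_per_row_reference(scores, indices)
# gathered_2d has shape [batch_size, size] = [2, 2]

# 3D case: inputs of shape [batch_size, list_size, feature_dims].
# (Illustrative data, not from the original documentation.)
features = [
    [[1., 10.], [3., 30.], [2., 20.]],
    [[1., 10.], [2., 20.], [3., 30.]],
]
gathered_3d = gather_per_row_reference(features, indices)
# gathered_3d has shape [batch_size, size, feature_dims] = [2, 2, 2]
```

In both cases the same `indices` are applied independently to each row, which is why the second row picks positions 2 and 1 rather than reusing the first row's positions.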