Pads invalid entries with valid ones and returns the nd_indices together with a mask.
```python
tfr.utils.padded_nd_indices(
    is_valid, shuffle=False, seed=None
)
```
For example, suppose batch_size = 1 and list_size = 3, and only the first 2 entries are valid:
```python
is_valid = [[True, True, False]]
nd_indices, mask = padded_nd_indices(is_valid)
```
nd_indices has shape [1, 3, 2] and mask has shape [1, 3]:
```python
nd_indices = [[[0, 0], [0, 1], [0, 0]]]
mask = [[True, True, False]]
```
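The padding rule can be sketched in plain Python. This is a minimal sketch, not the library's code: it assumes (consistently with the example above) that valid entries keep their original order and are repeated cyclically to fill the remaining slots, and it ignores shuffle and seed. padded_nd_indices_sketch is a hypothetical helper, not part of tfr.utils.

```python
def padded_nd_indices_sketch(is_valid):
    """Pure-Python sketch of the padding behaviour (no shuffling).

    Assumption: valid entries are moved to the front in their original order
    and then repeated cyclically to fill the remaining slots.
    """
    nd_indices, mask = [], []
    for batch_idx, row in enumerate(is_valid):
        # Positions of the valid entries in this row, in original order.
        valid_positions = [i for i, ok in enumerate(row) if ok]
        num_valid = len(valid_positions)
        # Assumption: a row with no valid entries falls back to position 0.
        cycle = valid_positions or [0]
        row_indices, row_mask = [], []
        for slot in range(len(row)):
            # Cycle through the valid positions to fill every slot.
            row_indices.append([batch_idx, cycle[slot % len(cycle)]])
            # Only the first num_valid slots hold an unrepeated valid entry.
            row_mask.append(slot < num_valid)
        nd_indices.append(row_indices)
        mask.append(row_mask)
    return nd_indices, mask


nd_indices, mask = padded_nd_indices_sketch([[True, True, False]])
print(nd_indices)  # [[[0, 0], [0, 1], [0, 0]]]
print(mask)        # [[True, True, False]]
```

Under this sketch, mask marks the first num_valid output slots rather than copying is_valid: for [[False, True, True, False]] it yields nd_indices [[[0, 1], [0, 2], [0, 1], [0, 2]]] and mask [[True, True, False, False]].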
nd_indices can be used with tf.gather_nd on a Tensor t
```python
padded_t = tf.gather_nd(t, nd_indices)
```
to obtain the following Tensor, whose first 2 dimensions are [1, 3]:
```python
padded_t = [[t(0, 0), t(0, 1), t(0, 0)]]
```
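With index pairs of the form [batch, position], tf.gather_nd(t, nd_indices) looks up t[batch][position] for every slot. The stand-in below is a hypothetical pure-Python helper (gather_nd_2d, not a TensorFlow function) that makes the lookup concrete on nested lists; trailing dimensions such as per-entry feature vectors are carried along unchanged.

```python
def gather_nd_2d(t, nd_indices):
    """Pure-Python stand-in for tf.gather_nd when each index is a [batch, position] pair.

    Any trailing dimensions of t (e.g. feature vectors) are carried along unchanged.
    """
    return [[t[b][i] for b, i in row] for row in nd_indices]


# A [1, 3] "Tensor" whose entries are labelled by their (batch, position).
t = [["t(0,0)", "t(0,1)", "t(0,2)"]]
nd_indices = [[[0, 0], [0, 1], [0, 0]]]
padded_t = gather_nd_2d(t, nd_indices)
print(padded_t)  # [['t(0,0)', 't(0,1)', 't(0,0)']]
```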
Args | |
---|---|
is_valid | A boolean Tensor for entry validity with shape [batch_size, list_size]. |
shuffle | A boolean that indicates whether valid indices should be shuffled. |
seed | Random seed for shuffle. |
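The effect of shuffle and seed can be sketched the same way, under loudly stated assumptions: padded_nd_indices_shuffled_sketch is a hypothetical helper, and Python's random.Random stands in for TensorFlow's random ops, so the concrete orders it produces will not match the library. It only illustrates the structure: valid positions are permuted before the cyclic padding, invalid positions are never selected, the mask does not depend on the shuffle, and a fixed seed makes the result reproducible.

```python
import random


def padded_nd_indices_shuffled_sketch(is_valid, shuffle=False, seed=None):
    """Pure-Python sketch of padding with optional shuffling of valid entries.

    Assumption: Python's random.Random stands in for TensorFlow's random ops,
    so the concrete order differs from the library; only the structure matches.
    """
    rng = random.Random(seed)
    nd_indices, mask = [], []
    for batch_idx, row in enumerate(is_valid):
        valid_positions = [i for i, ok in enumerate(row) if ok]
        num_valid = len(valid_positions)
        if shuffle:
            # Shuffle only the valid positions; invalid ones are never selected.
            rng.shuffle(valid_positions)
        # Assumption: fall back to position 0 when a row has no valid entries.
        cycle = valid_positions or [0]
        nd_indices.append(
            [[batch_idx, cycle[slot % len(cycle)]] for slot in range(len(row))])
        mask.append([slot < num_valid for slot in range(len(row))])
    return nd_indices, mask


is_valid = [[True, False, True, True, False]]
nd_indices, mask = padded_nd_indices_shuffled_sketch(is_valid, shuffle=True, seed=42)
first_pass = [pos for _, pos in nd_indices[0][:3]]
print(sorted(first_pass))  # [0, 2, 3]: each valid position appears once before padding
print(mask)                # [[True, True, True, False, False]]
```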
Returns | |
---|---|
A tuple of Tensors (nd_indices, mask). The first has shape [batch_size, list_size, 2] and can be used in gather_nd or scatter_nd. The second has shape [batch_size, list_size], with value True for valid (non-padded) positions and False for padded ones. |
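When nd_indices is used with tf.scatter_nd to write per-slot values back to the original list positions, the padded slots repeat indices of valid entries, and tf.scatter_nd sums values written to the same index. The sketch below uses a hypothetical pure-Python helper, scatter_nd_2d, that mimics that accumulation on nested lists; it shows one way the mask can be applied first so each original entry receives exactly one value.

```python
def scatter_nd_2d(nd_indices, updates, shape):
    """Pure-Python stand-in for tf.scatter_nd with [batch, position] indices.

    Like tf.scatter_nd, values written to the same index are summed.
    """
    batch_size, list_size = shape
    out = [[0.0] * list_size for _ in range(batch_size)]
    for idx_row, upd_row in zip(nd_indices, updates):
        for (b, i), value in zip(idx_row, upd_row):
            out[b][i] += value
    return out


nd_indices = [[[0, 0], [0, 1], [0, 0]]]
mask = [[True, True, False]]
# Per-slot values computed on the padded list, e.g. scores.
padded_scores = [[0.5, 2.0, 0.5]]

# Without masking, the padded copy of entry 0 is added twice.
unmasked = scatter_nd_2d(nd_indices, padded_scores, [1, 3])
print(unmasked)  # [[1.0, 2.0, 0.0]]

# Zeroing padded slots first restores one value per original entry.
masked_scores = [[v if m else 0.0 for v, m in zip(vr, mr)]
                 for vr, mr in zip(padded_scores, mask)]
restored = scatter_nd_2d(nd_indices, masked_scores, [1, 3])
print(restored)  # [[0.5, 2.0, 0.0]]
```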