Scatters updates into a tensor of shape shape according to indices.
tf.scatter_nd(indices, updates, shape, name=None)
Rather than modifying an existing input tensor, this op creates a new tensor and scatters the sparse values in updates into it at the positions given by indices.
This op returns an output tensor with the shape you specify. This op is the inverse of the tf.gather_nd operator, which extracts values or slices from a given tensor.
This operation is similar to tf.tensor_scatter_nd_add, except that the tensor is zero-initialized. Calling tf.scatter_nd(indices, updates, shape) is identical to calling tf.tensor_scatter_nd_add(tf.zeros(shape, updates.dtype), indices, updates).
If indices contains duplicates, the duplicate values are accumulated (summed).
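To make the zero-initialization and the summing of duplicates concrete, here is a minimal pure-Python sketch for a rank-1 output, with plain lists standing in for tensors. The helper names tensor_scatter_add_1d and scatter_nd_1d are hypothetical; they only model the semantics described above and are not TensorFlow APIs.

```python
def tensor_scatter_add_1d(tensor, indices, updates):
    """Return a copy of `tensor` with each update added at its index.

    `indices` has shape [N, 1], following tf.scatter_nd's convention for
    element-wise scatter into a rank-1 tensor. Hypothetical helper.
    """
    result = list(tensor)  # copy, so the input list is left unchanged
    for (i,), value in zip(indices, updates):
        result[i] += value  # '+=' is what makes duplicate indices sum
    return result


def scatter_nd_1d(indices, updates, size):
    """Model of scatter_nd: the same addition, starting from zeros."""
    return tensor_scatter_add_1d([0] * size, indices, updates)


# Zero-initialized scatter versus adding into an existing tensor.
print(scatter_nd_1d([[4], [3], [1], [7]], [9, 10, 11, 12], 8))
# [0, 11, 0, 10, 9, 0, 0, 12]
print(tensor_scatter_add_1d([1] * 8, [[4], [3], [1], [7]], [9, 10, 11, 12]))
# [1, 12, 1, 11, 10, 1, 1, 13]

# Index 1 appears twice, so its updates 5 and 7 are summed to 12.
print(scatter_nd_1d([[1], [1], [3]], [5, 7, 2], 5))
# [0, 12, 0, 2, 0]
```

Replacing the += with a plain assignment would turn this into overwrite semantics, where only the last duplicate survives; the accumulation rule above rules that out for scatter_nd.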
indices is an integer tensor containing indices into a new tensor of shape shape. The last dimension of indices can be at most the rank of shape:
indices.shape[-1] <= shape.rank
Each index vector along the last dimension of indices gives the coordinates of either a single element (if indices.shape[-1] = shape.rank) or a slice (if indices.shape[-1] < shape.rank). In the slice case, the slice spans the trailing dimensions shape[indices.shape[-1]:].
updates is a tensor with shape:
indices.shape[:-1] + shape[indices.shape[-1]:]
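The shape rules above can be checked mechanically. The following sketch, assuming shapes are given as plain tuples of ints, computes the required shape of updates and reports whether each index addresses an element or a slice. Both function names are hypothetical illustrations, not part of TensorFlow.

```python
def expected_updates_shape(indices_shape, shape):
    """Shape that `updates` must have under the rule above.

    Shapes are plain tuples of ints; this models the rule and is not a
    TensorFlow function.
    """
    index_depth = indices_shape[-1]
    if index_depth > len(shape):
        raise ValueError(
            f"indices.shape[-1]={index_depth} exceeds the rank {len(shape)} of shape")
    # Leading dims of indices, then the trailing dims each slice covers.
    return tuple(indices_shape[:-1]) + tuple(shape[index_depth:])


def scatter_kind(indices_shape, shape):
    """'elements' when each index is a full coordinate, else 'slices'."""
    return "elements" if indices_shape[-1] == len(shape) else "slices"


# The shapes used by the two examples in this section:
print(expected_updates_shape((4, 1), (8,)), scatter_kind((4, 1), (8,)))
# (4,) elements
print(expected_updates_shape((2, 1), (4, 4, 4)), scatter_kind((2, 1), (4, 4, 4)))
# (2, 4, 4) slices
```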
The simplest form of the scatter op is to insert individual elements in a tensor by index. Consider an example where you want to insert 4 scattered elements in a rank-1 tensor with 8 elements.
In Python, this scatter operation would look like this:
import tensorflow as tf
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)
The resulting tensor would look like this:
[0, 11, 0, 10, 9, 0, 0, 12]
You can also insert entire slices of a higher-rank tensor all at once. For example, you can insert two slices in the first dimension of a rank-3 tensor with two matrices of new values.
In Python, this scatter operation would look like this:
import tensorflow as tf
indices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]],
                       [[5, 5, 5, 5], [6, 6, 6, 6],
                        [7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
scatter = tf.scatter_nd(indices, updates, shape)
print(scatter)
The resulting tensor would look like this:
[[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
 [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],
 [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
 [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]]
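Both examples, and the element-wise case for a matrix that neither covers, follow from one rule: walk down the output using all but the last coordinate of an index, then add either a scalar or a whole slice at the last coordinate. The sketch below models that rule with nested Python lists. It assumes indices is 2-D ([N, index_depth]); the names zeros, add_into and scatter_nd_model are hypothetical, and this is not how TensorFlow implements the op.

```python
def zeros(shape):
    """Nested list of zeros with the given shape (a list of ints)."""
    if not shape:
        return 0
    return [zeros(shape[1:]) for _ in range(shape[0])]


def add_into(target, value):
    """Add nested list `value` into nested list `target` element-wise, in place."""
    for i, v in enumerate(value):
        if isinstance(v, list):
            add_into(target[i], v)
        else:
            target[i] += v


def scatter_nd_model(indices, updates, shape):
    """Pure-Python model of scatter_nd for 2-D `indices` ([N, index_depth]).

    No bounds checking is done, and Python's negative indexing is not
    rejected, so this is for illustration only.
    """
    out = zeros(shape)
    for index, update in zip(indices, updates):
        *outer, last = index
        container = out
        for i in outer:  # walk down to the parent of the target position
            container = container[i]
        if isinstance(update, list):
            add_into(container[last], update)  # slice: add element-wise
        else:
            container[last] += update  # single element
    return out


# Element-wise scatter into a 2 x 3 matrix (indices.shape[-1] == rank == 2).
print(scatter_nd_model([[0, 1], [1, 2]], [5, 6], [2, 3]))
# [[0, 5, 0], [0, 0, 6]]
```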
Note that on CPU, if an out-of-bounds index is found, an error is returned. On GPU, if an out-of-bounds index is found, the index is ignored.
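Because the same call can behave differently depending on the device, it can help to see the two behaviors side by side. This sketch models them in pure Python with a hypothetical on_out_of_bounds parameter; tf.scatter_nd has no such argument. It also assumes negative indices count as out of bounds.

```python
def scatter_nd_1d_checked(indices, updates, size, on_out_of_bounds="raise"):
    """Rank-1 scatter_nd model with explicit out-of-bounds handling.

    on_out_of_bounds="raise" mimics the documented CPU behavior (an error);
    "ignore" mimics the documented GPU behavior (the update is dropped).
    The parameter is hypothetical: tf.scatter_nd has no such argument.
    """
    if on_out_of_bounds not in ("raise", "ignore"):
        raise ValueError(f"unknown mode: {on_out_of_bounds!r}")
    out = [0] * size
    for (i,), value in zip(indices, updates):
        if not 0 <= i < size:  # negative indices are treated as invalid here
            if on_out_of_bounds == "raise":
                raise IndexError(f"index {i} is out of bounds for size {size}")
            continue  # "ignore": skip this update entirely
        out[i] += value
    return out


print(scatter_nd_1d_checked([[1], [9]], [5, 7], 4, on_out_of_bounds="ignore"))
# [0, 5, 0, 0]
try:
    scatter_nd_1d_checked([[1], [9]], [5, 7], 4)
except IndexError as err:
    print(err)
# index 9 is out of bounds for size 4
```

The practical consequence is that code which silently drops bad indices on GPU may fail on CPU, so validating indices before the call is the portable choice.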
Returns:
A Tensor. Has the same type as updates.