Return the elements where condition is True (multiplexing x and y).
tf.where(
condition, x=None, y=None, name=None
)
This operator has two modes: in one mode both x and y are provided, in the other neither is provided. condition is always expected to be a tf.Tensor of type bool.
Retrieving indices of True elements
If x and y are not provided (both are None):
tf.where will return the indices of condition that are True, in the form of a 2-D tensor with shape (n, d), where n is the number of matching indices in condition and d is the number of dimensions in condition.
Indices are output in row-major order.
tf.where([True, False, False, True])
<tf.Tensor: shape=(2, 1), dtype=int64, numpy=
array([[0],
       [3]])>
tf.where([[True, False], [False, True]])
<tf.Tensor: shape=(2, 2), dtype=int64, numpy=
array([[0, 0],
       [1, 1]])>
tf.where([[[True, False], [False, True], [True, True]]])
<tf.Tensor: shape=(4, 3), dtype=int64, numpy=
array([[0, 0, 0],
       [0, 1, 1],
       [0, 2, 0],
       [0, 2, 1]])>
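To make the (n, d) shape and the row-major ordering concrete, here is a minimal pure-Python sketch of the index mode for nested lists. The helper name true_indices is hypothetical and is not part of TensorFlow; it only mirrors the behavior described above, not how tf.where is implemented.

```python
def true_indices(condition, prefix=()):
    """Hypothetical helper: index lists of True entries, in row-major order."""
    out = []
    for i, item in enumerate(condition):
        if isinstance(item, list):
            # Descend into the next dimension, remembering the index so far.
            out.extend(true_indices(item, prefix + (i,)))
        elif item:
            # A True leaf: record its full d-dimensional index.
            out.append(list(prefix + (i,)))
    return out


print(true_indices([True, False, False, True]))
# [[0], [3]]
print(true_indices([[[True, False], [False, True], [True, True]]]))
# [[0, 0, 0], [0, 1, 1], [0, 2, 0], [0, 2, 1]]
```

Each inner list has one entry per dimension of the input (d), and there is one inner list per True element (n), which matches the shapes in the examples above.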
Multiplexing between x and y
If x and y are provided (both have non-None values):
tf.where will choose an output shape from the shapes of condition, x, and y that all three shapes are broadcastable to.
The condition tensor acts as a mask that chooses whether the corresponding element / row in the output should be taken from x (if the element in condition is True) or y (if it is False).
tf.where([True, False, False, True], [1,2,3,4], [100,200,300,400])
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 200, 300,   4], dtype=int32)>
tf.where([True, False, False, True], [1,2,3,4], [100])
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4], dtype=int32)>
tf.where([True, False, False, True], [1,2,3,4], 100)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   4], dtype=int32)>
tf.where([True, False, False, True], 1, 100)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([  1, 100, 100,   1], dtype=int32)>
tf.where(True, [1,2,3,4], 100)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([1, 2, 3, 4], dtype=int32)>
tf.where(False, [1,2,3,4], 100)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([100, 100, 100, 100], dtype=int32)>
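The output shape in the examples above can be worked out by hand. The following pure-Python sketch assumes NumPy-style broadcasting (shapes aligned at the trailing dimension, where a size of 1 stretches to match). broadcast_shape is a hypothetical helper written for illustration, not a TensorFlow function, and raising ValueError for incompatible shapes is a choice made here to echo the Raises entry of this page.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Hypothetical helper: NumPy-style broadcast of several shapes."""
    result = []
    # Align shapes at the trailing dimension; a missing dimension counts as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"shapes {shapes} are not broadcastable")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# Shapes of (condition, x, y) for some of the calls above.
print(broadcast_shape((4,), (4,), (1,)))    # (4,)  y = [100]
print(broadcast_shape((4,), (), ()))        # (4,)  x = 1, y = 100
print(broadcast_shape((), (4,), ()))        # (4,)  condition = True
print(broadcast_shape((2, 1), (1, 3), ()))  # (2, 3)
```

Scalars have the empty shape (), which is why a scalar condition, x, or y can be combined with a length-4 vector.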
Note that if the gradient of either branch of the tf.where generates a NaN, then the gradient of the entire tf.where will be NaN. This is because the gradient calculation for tf.where combines the two branches, for performance reasons.
A workaround is to use an inner tf.where to ensure the function has no asymptote, and to avoid computing a value whose gradient is NaN by replacing dangerous inputs with safe inputs.
Instead of this,
x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.where(x < 1., 0., 1. / x)
print(tape.gradient(y, x))
tf.Tensor(nan, shape=(), dtype=float32)
Although the 1. / x values are never used, their gradient is NaN when x = 0. Instead, we should guard that with another tf.where:
x = tf.constant(0., dtype=tf.float32)
with tf.GradientTape() as tape:
  tape.watch(x)
  safe_x = tf.where(tf.equal(x, 0.), 1., x)
  y = tf.where(x < 1., 0., 1. / safe_x)
print(tape.gradient(y, x))
tf.Tensor(0.0, shape=(), dtype=float32)
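The arithmetic behind the NaN can be sketched without TensorFlow. This is a simplified model, not TensorFlow's actual gradient code: it assumes the unselected branch receives an upstream gradient of zero and that the chain rule then multiplies this zero by the local derivative of 1/x. The helper reciprocal_grad is hypothetical.

```python
import math


def reciprocal_grad(upstream, x):
    """Chain rule for 1/x: upstream * d(1/x)/dx, where d(1/x)/dx = -1/x**2."""
    # Python raises ZeroDivisionError for 1/0, so spell out the IEEE limit.
    local = -math.inf if x == 0.0 else -1.0 / (x * x)
    return upstream * local


upstream = 0.0  # assumed: the unselected branch receives a zero gradient

unsafe = reciprocal_grad(upstream, 0.0)  # 0 * -inf
safe = reciprocal_grad(upstream, 1.0)    # 0 * -1, using safe_x = 1

print(math.isnan(unsafe))  # True
print(safe == 0.0)         # True
```

Under IEEE floating point, zero times infinity is NaN rather than zero, so masking the result after the fact cannot remove it. Replacing the dangerous input before the division keeps the local derivative finite, and the zero upstream gradient then yields zero.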
Args | |
---|---|
condition | A tf.Tensor of type bool. |
x | If provided, a Tensor which is of the same type as y, and has a shape broadcastable with condition and y. |
y | If provided, a Tensor which is of the same type as x, and has a shape broadcastable with condition and x. |
name | A name of the operation (optional). |
Returns | |
---|---|
If x and y are provided: a Tensor with the same type as x and y, and a shape that is broadcast from condition, x, and y. Otherwise, a Tensor with shape (num_true, dim_size(condition)). | |
Raises | |
---|---|
ValueError | When exactly one of x or y is non-None, or the shapes are not all broadcastable. |