Computes the triplet loss with semi-hard negative mining.
@tf.function
tfa.losses.triplet_semihard_loss(
    y_true: tfa.types.TensorLike,
    y_pred: tfa.types.TensorLike,
    margin: tfa.types.FloatTensorLike = 1.0,
    distance_metric: Union[str, Callable] = 'L2'
) -> tf.Tensor
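Usage:

>>> import tensorflow as tf
>>> import tensorflow_addons as tfa
>>> y_true = tf.convert_to_tensor([0, 0])
>>> y_pred = tf.convert_to_tensor([[0.0, 1.0], [1.0, 0.0]])
>>> tfa.losses.triplet_semihard_loss(y_true, y_pred, distance_metric="L2")
<tf.Tensor: shape=(), dtype=float32, numpy=2.4142137>
>>> # Calling with a callable `distance_metric`.
>>> # The Gram matrix of inner products is a similarity rather than a true
>>> # distance; it is used here only to illustrate the callable interface.
>>> distance_metric = lambda x: tf.linalg.matmul(x, x, transpose_b=True)
>>> tfa.losses.triplet_semihard_loss(y_true, y_pred, distance_metric=distance_metric)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>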
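To make the mining rule concrete, here is a minimal pure-Python sketch. It is not the tfa implementation; it assumes the following selection rule: for every anchor-positive pair (a, p), choose the closest negative whose distance from a is strictly greater than d(a, p); if no such negative exists, fall back to the farthest negative; then average max(margin + d(a, p) - d(a, n), 0) over all anchor-positive pairs. The names `semihard_triplet_loss` and `l2_pairwise` are hypothetical, and the behavior when an anchor has no negatives at all is an assumption labeled in the code. Under these assumptions the sketch reproduces both outputs in the usage example.

```python
import math
from typing import Callable, List, Sequence

Matrix = List[List[float]]


def l2_pairwise(emb: Sequence[Sequence[float]]) -> Matrix:
    """Pairwise Euclidean distances between all embeddings."""
    return [[math.dist(a, b) for b in emb] for a in emb]


def semihard_triplet_loss(
    labels: Sequence[int],
    emb: Sequence[Sequence[float]],
    margin: float = 1.0,
    distance: Callable[[Sequence[Sequence[float]]], Matrix] = l2_pairwise,
) -> float:
    """Hypothetical reference sketch of semi-hard triplet loss (not tfa code)."""
    d = distance(emb)
    n = len(labels)
    total, num_pairs = 0.0, 0
    for a in range(n):
        # Distances from anchor a to every embedding with a different label.
        negatives = [d[a][k] for k in range(n) if labels[k] != labels[a]]
        for p in range(n):
            if p == a or labels[p] != labels[a]:
                continue  # only anchor-positive pairs (same label, p != a) count
            num_pairs += 1
            # Semi-hard candidates: negatives strictly farther than the positive.
            farther = [x for x in negatives if x > d[a][p]]
            if farther:
                chosen = min(farther)  # closest negative beyond the positive
            elif negatives:
                chosen = max(negatives)  # none farther: use the farthest negative
            else:
                # Assumption: with no negatives at all, fall back to the row
                # minimum, which is what a masked maximum over an empty mask gives.
                chosen = min(d[a])
            total += max(margin + d[a][p] - chosen, 0.0)
    if num_pairs == 0:
        raise ValueError("batch needs at least one anchor-positive pair")
    return total / num_pairs
```

In the usage example both embeddings share a label, so there are no negatives and the "L2" loss reduces to margin + d(a, p) = 1 + sqrt(2), approximately 2.414.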
Args:
  y_true: 1-D integer Tensor with shape [batch_size] of multiclass
    integer labels.
  y_pred: 2-D float Tensor of embedding vectors, with shape
    [batch_size, embedding_dim]. Embeddings should be L2-normalized.
  margin: Float, the margin term in the loss definition.
  distance_metric: str or Callable that determines the distance metric.
    Valid strings are "L2" for L2-norm distance, "squared-L2" for squared
    L2-norm distance, and "angular" for angular distance
    (1 - cosine similarity). A Callable should take a batch of embeddings
    as input and return the pairwise distance matrix.
Returns:
  triplet_loss: float scalar with the same dtype as y_pred.
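The following pure-Python sketch shows what each string option for `distance_metric` puts in the pairwise matrix, assuming "angular" means 1 - cosine similarity. The helper names are hypothetical and not part of tfa, and the functions work on plain lists rather than tensors. The sketch also illustrates why L2-normalized embeddings matter: for unit-length vectors, ||a - b||^2 = 2(1 - cos θ), so "squared-L2" is exactly twice "angular", and all three options rank pairs identically, differing only in how the margin is scaled.

```python
import math
from typing import List, Sequence

Matrix = List[List[float]]


def pairwise_l2(emb: Sequence[Sequence[float]]) -> Matrix:
    # "L2": Euclidean distance between every pair of rows.
    return [[math.dist(a, b) for b in emb] for a in emb]


def pairwise_squared_l2(emb: Sequence[Sequence[float]]) -> Matrix:
    # "squared-L2": the same distances, squared.
    return [[math.dist(a, b) ** 2 for b in emb] for a in emb]


def pairwise_angular(emb: Sequence[Sequence[float]]) -> Matrix:
    # "angular" (assumed definition): 1 - cosine similarity, computed on
    # L2-normalized rows and clamped at 0 so rounding cannot go negative.
    unit = [[x / math.hypot(*v) for x in v] for v in emb]
    return [
        [max(1.0 - sum(x * y for x, y in zip(a, b)), 0.0) for b in unit]
        for a in unit
    ]
```

Each function has the shape the Args section asks of a Callable: a batch of embeddings in, a square pairwise distance matrix out.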