tfa.losses.sparsemax_loss

Sparsemax loss function [1].

Computes the generalized multi-label classification loss for the sparsemax function. The implementation is a reformulation of the original loss function so that it uses the sparsemax probability output instead of the internal threshold variable τ. The output is nonetheless identical to that of the original loss function.
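
The equivalence stated above can be checked numerically. The sketch below is plain Python, not the TensorFlow Addons implementation. It assumes a single 1-D logit vector z and a target distribution q, and the helper names sparsemax, loss_with_tau and loss_with_probs are hypothetical. It evaluates the original form L(z; q) = -q·z + ½ Σ_{j∈S(z)} (z_j² - τ²) + ½‖q‖², and a form that only needs the sparsemax output p. The two agree because p_j = z_j - τ on the support S(z), so ½(z_j² - τ²) = ½ p_j (z_j + τ) = p_j (z_j - p_j/2), while p_j = 0 off the support.

```python
def sparsemax(z):
    """Project z onto the probability simplex; return (p, tau)."""
    z_sorted = sorted(z, reverse=True)
    cumsum = 0.0
    k, support_sum = 0, 0.0
    # k(z) = largest k with 1 + k * z_(k) > sum of the k largest entries.
    for i, zi in enumerate(z_sorted, start=1):
        cumsum += zi
        if 1 + i * zi > cumsum:
            k, support_sum = i, cumsum
    tau = (support_sum - 1) / k  # threshold; k >= 1 always holds
    p = [max(zj - tau, 0.0) for zj in z]
    return p, tau


def loss_with_tau(z, q):
    """Original formulation: uses tau and the support S(z) = {j : p_j > 0}."""
    p, tau = sparsemax(z)
    support = [j for j, pj in enumerate(p) if pj > 0]
    return (-sum(qj * zj for qj, zj in zip(q, z))
            + 0.5 * sum(z[j] ** 2 - tau ** 2 for j in support)
            + 0.5 * sum(qj ** 2 for qj in q))


def loss_with_probs(z, p, q):
    """Reformulated version: only needs the sparsemax probabilities p."""
    return (sum(pj * (zj - 0.5 * pj) for pj, zj in zip(p, z))
            + sum(qj * (0.5 * qj - zj) for qj, zj in zip(q, z)))


z = [1.0, 0.5, -1.0]
q = [1.0, 0.0, 0.0]
p, tau = sparsemax(z)
print(p, tau)                                          # [0.75, 0.25, 0.0] 0.25
print(loss_with_tau(z, q), loss_with_probs(z, p, q))   # 0.0625 0.0625
```

This sketch omits the special handling of infinite logits that a production implementation needs (for example, 0 * -inf would produce nan), and it processes one example rather than a batch.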

Args:
  logits: A Tensor. Must be one of the following types: float32, float64.
  sparsemax: A Tensor containing the sparsemax probabilities computed from logits. Must have the same type as logits.
  labels: A Tensor. Must have the same type as logits.
  name: A name for the operation (optional).

Returns:
  A Tensor. Has the same type as logits.