tfa.activations.sparsemax

Sparsemax activation function.

For each batch \(i\) and class \(j\), compute the sparsemax activation function:

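\[ \mathrm{sparsemax}(\mathrm{logits})[i, j] = \max(\mathrm{logits}[i, j] - \tau(\mathrm{logits}[i, :]), 0), \]

where \(\tau(\mathrm{logits}[i, :])\) is the unique threshold for which the entries of row \(i\) of the output sum to 1.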
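To make the threshold \(\tau\) concrete, here is a minimal pure-Python sketch for a single row. It assumes the sort-based procedure from Martins and Astudillo's paper: sort the scores in decreasing order as \(z_{(1)} \ge z_{(2)} \ge \dots\), take the support size \(k\) as the largest index with \(1 + k z_{(k)} > \sum_{r \le k} z_{(r)}\), and set \(\tau = (\sum_{r \le k} z_{(r)} - 1)/k\). The function name sparsemax_1d is hypothetical; this is an illustration, not the TFA source.

```python
def sparsemax_1d(z):
    """Sparsemax of one non-empty row of scores (illustrative sketch, not TFA code).

    Returns the transformed row and the threshold tau.
    """
    # Sort the scores in decreasing order.
    z_sorted = sorted(z, reverse=True)

    # The support is a prefix of the sorted scores: keep the largest k with
    # 1 + k * z_(k) > z_(1) + ... + z_(k). k = 1 always qualifies.
    running_sum = 0.0
    k_max = 0
    sum_at_k = 0.0
    for k, z_k in enumerate(z_sorted, start=1):
        running_sum += z_k
        if 1.0 + k * z_k > running_sum:
            k_max = k
            sum_at_k = running_sum

    # Threshold chosen so that the clipped outputs sum to 1.
    tau = (sum_at_k - 1.0) / k_max

    # Subtract tau and clip at zero, keeping the original class order.
    return [max(z_j - tau, 0.0) for z_j in z], tau
```

For example, sparsemax_1d([0.5, 0.4, -1.0]) gives \(\tau \approx -0.05\) and an output of approximately [0.55, 0.45, 0.0]: two classes keep nonzero probability and the third is exactly zero.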

See "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification" (Martins and Astudillo, ICML 2016).

Usage:

import tensorflow as tf
import tensorflow_addons as tfa
x = tf.constant([[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]])
tfa.activations.sparsemax(x)
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[0., 0., 1.],
       [0., 0., 1.]], dtype=float32)>
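The printed result can be cross-checked with a small vectorized NumPy reference. This sketch assumes NumPy and the same sort-based threshold computation described in the paper; the helpers sparsemax and softmax are hypothetical names, not part of TFA. It also shows the key contrast: softmax gives every class a strictly positive probability, while sparsemax returns exact zeros.

```python
import numpy as np


def sparsemax(logits):
    """Sparsemax over the last axis (NumPy sketch, not the TFA implementation)."""
    z = np.asarray(logits, dtype=np.float64)
    # Sort each row in decreasing order.
    z_sorted = -np.sort(-z, axis=-1)
    k = np.arange(1, z.shape[-1] + 1)
    cumsum = np.cumsum(z_sorted, axis=-1)
    # True for every k inside the support; the support is a prefix of the sorted row.
    support = 1.0 + k * z_sorted > cumsum
    k_z = support.sum(axis=-1, keepdims=True)
    # Sum of the k_z largest entries of each row.
    sum_at_k = np.take_along_axis(cumsum, k_z - 1, axis=-1)
    tau = (sum_at_k - 1.0) / k_z
    return np.maximum(z - tau, 0.0)


def softmax(logits):
    """Numerically stable softmax over the last axis, for comparison."""
    z = np.asarray(logits, dtype=np.float64)
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


x = [[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]]
sparse_probs = sparsemax(x)
dense_probs = softmax(x)
```

For the x above, this sketch returns [[0., 0., 1.], [0., 0., 1.]], matching the TFA output, while softmax(x) has no zero entries.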

Args:
logits: A Tensor.
axis: int, the axis along which the sparsemax operation is applied (default -1, the last axis).
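One way to honor the axis argument is to move the chosen axis to the end, apply the last-axis transformation, and move it back, so the result keeps the input's shape and sums to 1 along axis. The NumPy sketch below illustrates that approach under this assumption; it is not the TFA implementation, the function names and error message are hypothetical, and the rank-1 check only mirrors the ValueError documented for one-dimensional inputs.

```python
import numpy as np


def _sparsemax_last_axis(z):
    # Sort-based sparsemax along the last axis of a float array.
    z_sorted = -np.sort(-z, axis=-1)
    k = np.arange(1, z.shape[-1] + 1)
    cumsum = np.cumsum(z_sorted, axis=-1)
    # Support size per slice: count of k with 1 + k * z_(k) > cumsum_k.
    k_z = (1.0 + k * z_sorted > cumsum).sum(axis=-1, keepdims=True)
    tau = (np.take_along_axis(cumsum, k_z - 1, axis=-1) - 1.0) / k_z
    return np.maximum(z - tau, 0.0)


def sparsemax(logits, axis=-1):
    """Apply sparsemax along `axis` (NumPy sketch of the documented semantics)."""
    z = np.asarray(logits, dtype=np.float64)
    if z.ndim == 1:
        # Hypothetical message; mirrors the documented ValueError for rank-1 input.
        raise ValueError("logits must have at least 2 dimensions")
    # Move the target axis to the end, transform, then restore the layout.
    moved = np.moveaxis(z, axis, -1)
    out = _sparsemax_last_axis(moved)
    return np.moveaxis(out, -1, axis)


x = [[-1.0, 0.0, 1.0], [-5.0, 1.0, 2.0]]
rows = sparsemax(x)          # normalize each row (axis=-1)
cols = sparsemax(x, axis=0)  # normalize each column instead
```

With the usage input, sparsemax(x, axis=0) normalizes each column rather than each row and returns [[1., 0., 0.], [0., 1., 1.]].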

Returns:
A Tensor, the output of the sparsemax transformation, with the same type and shape as logits.

Raises:
ValueError: If logits has rank 1 (dim(logits) == 1).