Calculate a per-batch sparse categorical crossentropy loss.
tfm.nlp.losses.weighted_sparse_categorical_crossentropy_loss(
labels, predictions, weights=None, from_logits=False
)
By default (from_logits=False), this loss function assumes that the predictions are post-softmax probabilities. Set from_logits=True to pass raw logits instead.
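The two input modes differ only in where the softmax happens. The sketch below is a minimal pure-Python version of the per-example loss, -log p[label], and is not the library's implementation. `softmax`, `log_softmax`, and `example_loss` are hypothetical helpers. Unlike a production loss, this sketch does no clipping.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits):
    # log(softmax(x)) computed directly with the log-sum-exp trick.
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]

def example_loss(label, prediction, from_logits=False):
    """Hypothetical per-example sparse categorical crossentropy.

    label: integer class index in [0, vocab_size - 1].
    prediction: one row of scores for a single example.
    """
    if from_logits:
        # Raw scores: normalize in log space, then pick the true class.
        return -log_softmax(prediction)[label]
    # Post-softmax input: the row is already a probability distribution.
    return -math.log(prediction[label])

logits = [2.0, 1.0, 0.1]
probs = softmax(logits)

loss_from_probs = example_loss(0, probs, from_logits=False)
loss_from_logits = example_loss(0, logits, from_logits=True)
# Mismatch: raw logits treated as if they were probabilities.
wrong = example_loss(0, logits, from_logits=False)

print(loss_from_probs, loss_from_logits, wrong)
```

The first two values agree to floating-point precision. The third passes raw logits while claiming they are probabilities. In this unclipped sketch, that mismatch even produces a negative loss, and nothing raises an error.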
Args:
labels: The labels to evaluate against. Should be a set of integer indices
ranging from 0 to (vocab_size-1).
predictions: The network predictions. Should have softmax already applied unless from_logits is True.
weights: An optional weight array of the same shape as the 'labels' array.
If None, all examples are weighted equally.
from_logits: Whether the input predictions are raw logits rather than softmax probabilities.
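One way to see what the weights do is to reduce the per-example losses to a single per-batch scalar. The sketch below assumes a weighted average, sum(w_i * loss_i) / sum(w_i), falling back to a plain mean when weights is None. `weighted_mean_loss` is a hypothetical helper, not the library function, and returning 0.0 for all-zero weights is this sketch's own convention.

```python
import math

def weighted_mean_loss(labels, probs, weights=None):
    """Hypothetical per-batch reduction of sparse categorical crossentropy.

    labels:  list of integer class indices, one per example.
    probs:   list of post-softmax probability rows, one per example.
    weights: optional list of floats, same length as labels.
    """
    losses = [-math.log(row[label]) for label, row in zip(labels, probs)]
    if weights is None:
        # Every example counts equally: a plain mean.
        return sum(losses) / len(losses)
    total_weight = sum(weights)
    if total_weight == 0:
        # This sketch returns 0.0 instead of dividing by zero.
        return 0.0
    return sum(w * l for w, l in zip(weights, losses)) / total_weight

labels = [0, 2, 1]
probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.1, 0.8],
    [0.3, 0.4, 0.3],
]
# A zero weight masks out the last example (e.g. a padding position).
weights = [1.0, 1.0, 0.0]

print(weighted_mean_loss(labels, probs))
print(weighted_mean_loss(labels, probs, weights))
```

Zero weights are a common way to exclude padding positions, since the masked example contributes to neither the numerator nor the denominator.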
Raises:
RuntimeError: if the passed tensors do not have the same rank.
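The rank condition is stated only briefly. One plausible reading, assumed here rather than taken from this page, is that predictions carry one extra trailing axis (the class scores) that labels lack, and that weights, when given, match the rank of labels. The hypothetical `check_ranks` helper below checks shapes given as plain tuples under that assumption.

```python
def check_ranks(labels_shape, predictions_shape, weights_shape=None):
    """Hypothetical rank check on shapes given as tuples.

    Assumes predictions have one more axis (the class axis) than labels,
    and that weights, when provided, have the same rank as labels.
    """
    if weights_shape is not None and len(weights_shape) != len(labels_shape):
        raise RuntimeError(
            f"weights rank {len(weights_shape)} != labels rank {len(labels_shape)}")
    if len(predictions_shape) - 1 != len(labels_shape):
        raise RuntimeError(
            f"labels rank {len(labels_shape)} should be one less than "
            f"predictions rank {len(predictions_shape)}")

# Batch of 4 sequences, 16 tokens each, vocabulary of 1000.
check_ranks((4, 16), (4, 16, 1000), (4, 16))  # passes silently

try:
    # Per-sequence weights do not match per-token labels.
    check_ranks((4, 16), (4, 16, 1000), (4,))
except RuntimeError as err:
    print("rejected:", err)
```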