Computes mixture EM loss between y_true and y_pred.
tfr.keras.losses.MixtureEMLoss(
reduction: tf.losses.Reduction = tf.losses.Reduction.AUTO,
name: Optional[str] = None,
lambda_weight: Optional[losses_impl._LambdaWeight] = None,
temperature: float = 1.0,
alpha: float = 1.0,
ragged: bool = False
)
Implementation of mixture Expectation-Maximization loss (Yan et al., 2018). This loss assumes that the clicks in a session are generated by one of the models in a mixture.
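The mixture idea can be sketched in plain NumPy: each latent model scores the whole list, the listwise softmax log-likelihood of the observed clicks is computed per model, and the per-list loss combines the models with a log-sum-exp over a uniform prior. This is an illustrative simplification for intuition only, not tfr's exact MixtureEMLoss (which uses EM-style model-probability weighting with the alpha smoothing factor), so its value will not match the 1.3198698 shown below:

```python
import numpy as np

def mixture_listwise_loss_sketch(y_true, y_pred):
    """Illustrative mixture-of-listwise-softmax loss (NOT tfr's exact loss).

    y_true: [batch, list_size] click labels.
    y_pred: [batch, list_size, num_models] scores, one column per latent model.
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    # Listwise log-softmax per latent model: normalize over the list axis.
    z = y_pred - y_pred.max(axis=1, keepdims=True)  # numerical stability
    log_softmax = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    # Log-likelihood of the observed clicks under each model: [batch, num_models].
    per_model_loglik = (y_true[..., None] * log_softmax).sum(axis=1)
    # Combine models under a uniform prior via log-sum-exp (soft assignment).
    m = per_model_loglik.max(axis=1, keepdims=True)
    mix_loglik = m.squeeze(1) + np.log(np.exp(per_model_loglik - m).mean(axis=1))
    # Negative log-likelihood, averaged over lists.
    return -mix_loglik.mean()

# Same shapes as the standalone usage example: batch=1, list_size=2, 2 models.
print(mixture_listwise_loss_sketch([[1., 0.]], [[[0.6, 0.9], [0.8, 0.2]]]))
```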
Standalone usage:
y_true = [[1., 0.]]
y_pred = [[[0.6, 0.9], [0.8, 0.2]]]
loss = tfr.keras.losses.MixtureEMLoss()
loss(y_true, y_pred).numpy()
1.3198698
# Using ragged tensors
y_true = tf.ragged.constant([[1., 0.], [0., 1., 0.]])
y_pred = tf.ragged.constant([[[0.6, 0.9], [0.8, 0.2]],
[[0.5, 0.9], [0.8, 0.2], [0.4, 0.8]]])
loss = tfr.keras.losses.MixtureEMLoss(ragged=True)
loss(y_true, y_pred).numpy()
1.909512
Usage with the compile() API:
model.compile(optimizer='sgd', loss=tfr.keras.losses.MixtureEMLoss())
Args
  reduction: (Optional) The tf.keras.losses.Reduction to use (see tf.keras.losses.Loss).
  name: (Optional) The name for the op.
  lambda_weight: (Optional) A lambdaweight to apply to the loss. Can be one of tfr.keras.losses.DCGLambdaWeight, tfr.keras.losses.NDCGLambdaWeight, or tfr.keras.losses.PrecisionLambdaWeight.
  temperature: (Optional) The temperature used to scale the logits.
  alpha: (Optional) The smoothing factor for the model probability.
  ragged: (Optional) If True, this loss accepts ragged tensors; if False, it accepts dense tensors.
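Temperature scaling of logits is the standard softmax-temperature trick: logits are divided by the temperature before the softmax, so temperatures above 1 flatten the resulting distribution and temperatures below 1 sharpen it. A minimal NumPy illustration of that effect (not tfr internals):

```python
import numpy as np

def softmax_with_temperature(logits, temperature=1.0):
    """Softmax over temperature-scaled logits.

    Higher temperature -> flatter distribution; lower -> sharper.
    """
    z = np.asarray(logits, dtype=np.float64) / temperature
    z -= z.max()  # numerical stability
    e = np.exp(z)
    return e / e.sum()

scores = [2.0, 1.0, 0.1]
print(softmax_with_temperature(scores, temperature=1.0))
print(softmax_with_temperature(scores, temperature=5.0))  # noticeably flatter
```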
Methods
from_config
@classmethod
from_config( config, custom_objects=None )
Instantiates a Loss from its config (output of get_config()).
Args
  config: Output of get_config().

Returns
  A Loss instance.
get_config
get_config() -> Dict[str, Any]
Returns the config dictionary for a Loss instance.
__call__
__call__(
    y_true: tfr.keras.model.TensorLike,
    y_pred: tfr.keras.model.TensorLike,
    sample_weight: Optional[utils.TensorLike] = None
) -> tf.Tensor

See tf.keras.losses.Loss.
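sample_weight follows the usual Keras loss convention: per-list losses are scaled by the weights before reduction. The NumPy sketch below assumes SUM_OVER_BATCH_SIZE-style behavior, where the scaled losses are averaged over the number of lists rather than the weight sum; see tf.keras.losses.Loss for the authoritative reduction rules:

```python
import numpy as np

def reduce_with_sample_weight(per_list_losses, sample_weight=None):
    """Sketch of SUM_OVER_BATCH_SIZE-style reduction (assumed behavior).

    Losses are scaled elementwise by sample_weight, then averaged over the
    number of lists (not over the sum of the weights).
    """
    losses = np.asarray(per_list_losses, dtype=np.float64)
    if sample_weight is not None:
        losses = losses * np.asarray(sample_weight, dtype=np.float64)
    return losses.mean()

print(reduce_with_sample_weight([1.0, 3.0]))              # 2.0
print(reduce_with_sample_weight([1.0, 3.0], [1.0, 0.0]))  # 0.5
```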