tf.keras.losses.Reduction
Types of loss reduction.
Contains the following values:
AUTO: Indicates that the reduction option will be determined by the usage context. For almost all cases this uses SUM_OVER_BATCH_SIZE.
When used with tf.distribute.Strategy outside of built-in training loops such as tf.keras compile and fit, we expect the reduction value to be SUM or NONE. Using AUTO in that case will raise an error.
NONE: No additional reduction is applied to the output of the wrapped loss function. When non-scalar losses are returned to Keras functions like fit/evaluate, the unreduced vector loss is passed to the optimizer, but the reported loss will be a scalar value.
Caution: Verify the shape of the outputs when using Reduction.NONE.
The built-in loss functions wrapped by the loss classes reduce one dimension (axis=-1, or axis if specified by the loss function). Reduction.NONE just means that no additional reduction is applied by the class wrapper. For categorical losses with an example input shape of [batch, W, H, n_classes], the n_classes dimension is reduced. For pointwise losses you must include a dummy axis so that [batch, W, H, 1] is reduced to [batch, W, H]. Without the dummy axis, [batch, W, H] will be incorrectly reduced to [batch, W].
SUM: Scalar sum of weighted losses.
SUM_OVER_BATCH_SIZE: Scalar SUM divided by the number of elements in losses. This reduction type is not supported when used with tf.distribute.Strategy outside of built-in training loops like tf.keras compile/fit.
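You can implement SUM_OVER_BATCH_SIZE using the global batch size like this:

import tensorflow as tf

# strategy, labels, predictions, and global_batch_size come from the surrounding custom training loop.
with strategy.scope():
    loss_obj = tf.keras.losses.CategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE)
    # ....
    # Sum the per-example losses, then divide by the global (not per-replica) batch size.
    loss = tf.reduce_sum(loss_obj(labels, predictions)) * (1. / global_batch_size)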
Please see the custom training guide for more
details on this.
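The following sketch models the four reduction values in plain Python on a flat list of per-element losses, so it runs without TensorFlow. The names reduce_losses and replica_loss and the distributed_custom_loop flag are hypothetical, and the AUTO branch assumes the common case in which AUTO resolves to SUM_OVER_BATCH_SIZE. It also illustrates why the global batch size matters, assuming per-replica values are summed across replicas (as a custom loop typically does with strategy.reduce(tf.distribute.ReduceOp.SUM, ...)): with uneven local batches, summing per-replica SUM_OVER_BATCH_SIZE values does not give the global mean, while NONE followed by division by the global batch size does.

```python
# Hypothetical pure-Python model of the reduction values; not Keras internals.
AUTO = "auto"
NONE = "none"
SUM = "sum"
SUM_OVER_BATCH_SIZE = "sum_over_batch_size"


def reduce_losses(losses, sample_weights=None, reduction=AUTO,
                  distributed_custom_loop=False):
    """Reduce a flat list of per-element losses (hypothetical helper)."""
    if sample_weights is None:
        sample_weights = [1.0] * len(losses)
    weighted = [loss * w for loss, w in zip(losses, sample_weights)]

    # Inside tf.distribute.Strategy but outside compile/fit, only SUM and
    # NONE are accepted; AUTO and SUM_OVER_BATCH_SIZE raise an error.
    if distributed_custom_loop and reduction in (AUTO, SUM_OVER_BATCH_SIZE):
        raise ValueError("Use SUM or NONE in a distributed custom loop.")
    if reduction == AUTO:
        # Assumption: the common case, where AUTO means SUM_OVER_BATCH_SIZE.
        reduction = SUM_OVER_BATCH_SIZE

    if reduction == NONE:
        return weighted                       # no additional reduction
    if reduction == SUM:
        return sum(weighted)                  # scalar sum of weighted losses
    if reduction == SUM_OVER_BATCH_SIZE:
        return sum(weighted) / len(weighted)  # SUM / number of elements
    raise ValueError(f"Invalid reduction: {reduction!r}")


def replica_loss(per_example_losses, global_batch_size):
    """NONE on one replica, then scale by the *global* batch size."""
    unreduced = reduce_losses(per_example_losses, reduction=NONE)
    return sum(unreduced) * (1.0 / global_batch_size)


# Two replicas with uneven local batches; the global batch size is 4.
replica_a, replica_b = [1.0, 2.0, 3.0], [4.0]
global_mean = reduce_losses(replica_a + replica_b, reduction=SUM_OVER_BATCH_SIZE)
scaled = replica_loss(replica_a, 4) + replica_loss(replica_b, 4)
naive = (reduce_losses(replica_a, reduction=SUM_OVER_BATCH_SIZE)
         + reduce_losses(replica_b, reduction=SUM_OVER_BATCH_SIZE))
print(global_mean, scaled, naive)  # 2.5 2.5 6.0
```

In TensorFlow itself, tf.nn.compute_average_loss(per_example_loss, global_batch_size=...) performs the same sum-then-divide-by-global-batch-size scaling, which may be more convenient than writing it by hand.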
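To make the dummy-axis caution for Reduction.NONE concrete, the sketch below uses nested Python lists in place of tensors and a pointwise loss that averages squared error over the last axis, the way Keras's built-in mean squared error reduces axis=-1. Labels of shape [batch, W, H] lose the H dimension, while the same labels with a trailing axis of size 1 keep it. The helpers shape, pointwise_loss_last_axis, and add_dummy_axis are hypothetical illustrations, not TensorFlow functions.

```python
# Illustrative only: nested lists stand in for tensors, and all helper names
# are hypothetical (not TensorFlow APIs).

def shape(x):
    """Return the shape of a regular (rectangular) nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return dims


def pointwise_loss_last_axis(y_true, y_pred):
    """Squared error averaged over the last axis (axis=-1), like a built-in MSE."""
    if not isinstance(y_true[0], list):
        # Innermost axis reached: collapse it to a single scalar.
        errs = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
        return sum(errs) / len(errs)
    return [pointwise_loss_last_axis(t, p) for t, p in zip(y_true, y_pred)]


def add_dummy_axis(x):
    """Append a trailing axis of size 1: [..., H] -> [..., H, 1]."""
    if not isinstance(x, list):
        return [x]
    return [add_dummy_axis(v) for v in x]


batch, W, H = 2, 2, 3
y_true = [[[0.0] * H for _ in range(W)] for _ in range(batch)]
y_pred = [[[1.0] * H for _ in range(W)] for _ in range(batch)]

# Without the dummy axis, H is averaged away: [batch, W, H] -> [batch, W].
print(shape(pointwise_loss_last_axis(y_true, y_pred)))  # [2, 2]

# With the dummy axis, only the size-1 axis is reduced: -> [batch, W, H].
print(shape(pointwise_loss_last_axis(add_dummy_axis(y_true),
                                     add_dummy_axis(y_pred))))  # [2, 2, 3]
```

With real tensors, the trailing axis can be added with tf.expand_dims(y, axis=-1) before calling the loss.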
Methods
all
@classmethod
all()
validate
@classmethod
validate(key)
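The Methods section lists only signatures. The sketch below is a plain-Python stand-in named ReductionSketch (a hypothetical name), written under the assumption that all() returns every valid reduction value and validate(key) raises ValueError for any other key; the exact return type and error message in the Keras source may differ.

```python
class ReductionSketch:
    """Hypothetical pure-Python stand-in for tf.keras.losses.Reduction."""

    AUTO = "auto"
    NONE = "none"
    SUM = "sum"
    SUM_OVER_BATCH_SIZE = "sum_over_batch_size"

    @classmethod
    def all(cls):
        # Assumed behavior: return every valid reduction value.
        return (cls.AUTO, cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)

    @classmethod
    def validate(cls, key):
        # Assumed behavior: reject anything that is not one of all().
        if key not in cls.all():
            raise ValueError(
                f"Invalid Reduction Key: {key!r}. Expected one of {cls.all()}")


ReductionSketch.validate("sum")  # accepted; returns None
try:
    ReductionSketch.validate("mean")
except ValueError as err:
    print(err)
```

Validating once, up front, lets a typo such as "mean" fail with a clear message instead of surfacing later inside the loss computation.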
Class Variables
AUTO = 'auto'
NONE = 'none'
SUM = 'sum'
SUM_OVER_BATCH_SIZE = 'sum_over_batch_size'
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-01-23 UTC.