tf.compat.v1.train.cosine_decay
Applies cosine decay to the learning rate.
tf.compat.v1.train.cosine_decay(
    learning_rate, global_step, decay_steps, alpha=0.0, name=None
)
When training a model, it is often helpful to lower the learning rate as
training progresses. This function applies a cosine decay schedule
to a provided initial learning rate. It requires a global_step value to
compute the decayed learning rate; you can pass a TensorFlow variable
that you increment at each training step.
The function returns the decayed learning rate. It is computed as:
global_step = min(global_step, decay_steps)
cosine_decay = 0.5 * (1 + cos(pi * global_step / decay_steps))
decayed = (1 - alpha) * cosine_decay + alpha
decayed_learning_rate = learning_rate * decayed
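The four lines above can be checked by hand with a small pure-Python sketch. It is not the TensorFlow op: it works on plain Python numbers rather than Tensors, and the helper name `cosine_decay_value` is made up for illustration.

```python
import math


def cosine_decay_value(learning_rate, global_step, decay_steps, alpha=0.0):
    """Pure-Python sketch of the formula above (not the TensorFlow op)."""
    # Clamp the step so the rate stays at its floor once decay_steps is reached.
    step = min(global_step, decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    decayed = (1.0 - alpha) * cosine + alpha
    return learning_rate * decayed


if __name__ == "__main__":
    # Print the rate at a few steps, including one past decay_steps.
    for step in (0, 250, 500, 1000, 2000):
        print(step, cosine_decay_value(0.1, step, decay_steps=1000))
```

Because of the clamp on the first line of the formula, any step beyond decay_steps gives the same value as step == decay_steps.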
Example usage:
import tensorflow as tf

decay_steps = 1000
lr_decayed = tf.compat.v1.train.cosine_decay(learning_rate, global_step, decay_steps)
Args
learning_rate | A scalar float32 or float64 Tensor or a Python number. The initial learning rate.
global_step | A scalar int32 or int64 Tensor or a Python number. Global step to use for the decay computation.
decay_steps | A scalar int32 or int64 Tensor or a Python number. Number of steps to decay over.
alpha | A scalar float32 or float64 Tensor or a Python number. Minimum learning rate value as a fraction of learning_rate.
name | String. Optional name of the operation. Defaults to 'CosineDecay'.
Returns
A scalar Tensor of the same type as learning_rate. The decayed learning rate.
Raises
ValueError | if global_step is not supplied.
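To make the role of alpha concrete, the documented formula can be evaluated at its two endpoints. The symbols here are introduced only for brevity: $\eta_0$ stands for learning_rate, $t$ for global_step, $T$ for decay_steps, and $\eta$ for the decayed learning rate.

```latex
\begin{aligned}
t = 0:&\quad \tfrac{1}{2}\bigl(1 + \cos 0\bigr) = 1
  &&\Rightarrow\quad \eta = \eta_0\bigl((1-\alpha)\cdot 1 + \alpha\bigr) = \eta_0 \\
t \ge T:&\quad \tfrac{1}{2}\bigl(1 + \cos \pi\bigr) = 0
  &&\Rightarrow\quad \eta = \eta_0\bigl((1-\alpha)\cdot 0 + \alpha\bigr) = \alpha\,\eta_0
\end{aligned}
```

For instance, with learning_rate = 0.1 and alpha = 0.1, the rate would start at 0.1 and settle at 0.01 once global_step reaches decay_steps.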
eager compatibility
When eager execution is enabled, this function returns a function which in
turn returns the decayed learning rate Tensor. This can be useful for changing
the learning rate value across different invocations of optimizer functions.
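The "function that returns the learning rate" pattern can be illustrated without TensorFlow. The sketch below is a hypothetical pure-Python stand-in, not the library's implementation: `make_decayed_lr_fn`, `state`, and `lr_fn` are invented names, and it assumes the returned callable re-reads the step each time it is called.

```python
import math


def make_decayed_lr_fn(learning_rate, get_step, decay_steps, alpha=0.0):
    """Hypothetical stand-in: returns a callable instead of a fixed value."""
    def decayed_lr():
        # Read the step at call time, so each call reflects training progress.
        step = min(get_step(), decay_steps)
        cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
        return learning_rate * ((1.0 - alpha) * cosine + alpha)
    return decayed_lr


state = {"step": 0}  # plays the role of the global_step variable
lr_fn = make_decayed_lr_fn(0.1, lambda: state["step"], decay_steps=1000)

lr_at_start = lr_fn()   # evaluated while step == 0
state["step"] = 1000    # simulate 1000 training steps
lr_at_end = lr_fn()     # same callable, new value
```

Deferring the computation this way lets the same object be handed to an optimizer once and still yield a different rate on each invocation.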
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-01-23 UTC.