Creates a train_step that evaluates the gradients and returns the loss.
tf_agents.utils.eager_utils.create_train_step(
    loss,
    optimizer,
    global_step=_USE_GLOBAL_STEP,
    total_loss_fn=None,
    update_ops=None,
    variables_to_train=None,
    transform_grads_fn=None,
    summarize_gradients=False,
    gate_gradients=tf.compat.v1.train.Optimizer.GATE_OP,
    aggregation_method=None,
    check_numerics=True
)
Args:

  loss: A (possibly nested tuple of) Tensor or function representing the
    loss.
  optimizer: A tf.Optimizer to use for computing the gradients.
  global_step: A Tensor representing the global step variable. If left as
    _USE_GLOBAL_STEP, then tf.train.get_or_create_global_step() is used.
  total_loss_fn: Function to call on the loss value to access the final item
    to minimize.
  update_ops: An optional list of updates to execute. If update_ops is None,
    then the update ops are set to the contents of the
    tf.GraphKeys.UPDATE_OPS collection. If update_ops is not None but does
    not contain all of the update ops in tf.GraphKeys.UPDATE_OPS, a warning
    is displayed.
  variables_to_train: An optional list of variables to train. If None,
    defaults to all tf.trainable_variables().
  transform_grads_fn: A function that takes a single argument, a list of
    (gradient, variable) pairs, performs any requested gradient updates
    (such as gradient clipping or multipliers), and returns the updated
    list.
  summarize_gradients: Whether or not to add summaries for each gradient.
  gate_gradients: How to gate the computation of gradients. See
    tf.Optimizer.
  aggregation_method: Specifies the method used to combine gradient terms.
    Valid values are defined in the class AggregationMethod.
  check_numerics: Whether or not to apply check_numerics.
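The transform_grads_fn contract described above can be illustrated with a minimal sketch. Plain Python floats stand in for gradient tensors here; a real implementation would use TF ops such as tf.clip_by_value or tf.clip_by_global_norm.

```python
def clip_transform(grads_and_vars, max_val=1.0):
    """Hypothetical transform_grads_fn: clamp each gradient into
    [-max_val, max_val] and return the updated list of
    (gradient, variable) pairs, as create_train_step expects."""
    return [(max(-max_val, min(max_val, g)), v) for g, v in grads_and_vars]

# Plain floats stand in for gradient tensors in this illustration.
pairs = [(2.5, "weight"), (-0.3, "bias")]
print(clip_transform(pairs))  # -> [(1.0, 'weight'), (-0.3, 'bias')]
```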
Returns:

  In graph mode: A (possibly nested tuple of) Tensor that, when evaluated,
  calculates the current loss, computes the gradients, applies the
  optimizer, and returns the current loss.
  In eager mode: A lambda function that, when called, calculates the loss,
  then computes and applies the gradients, and returns the original loss
  values.
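The eager-mode behavior can be pictured with a small, self-contained analogy (not the real TF-Agents implementation): the returned callable evaluates the loss, computes gradients, applies an optimizer update, and returns the pre-update loss. Gradients are approximated by finite differences here so the sketch needs no TensorFlow.

```python
def make_train_step(loss_fn, params, lr=0.1, eps=1e-6):
    """Toy analogy of the eager-mode return value: a callable that
    computes the loss, then computes and applies gradients, and
    returns the original (pre-update) loss."""
    def train_step():
        loss = loss_fn(params)
        # Finite-difference gradient of loss_fn w.r.t. each parameter.
        grads = []
        for i in range(len(params)):
            bumped = list(params)
            bumped[i] += eps
            grads.append((loss_fn(bumped) - loss) / eps)
        # Apply a plain gradient-descent update, as an optimizer would.
        for i, g in enumerate(grads):
            params[i] -= lr * g
        return loss
    return train_step

params = [3.0]
train_step = make_train_step(lambda p: (p[0] - 1.0) ** 2, params)
for _ in range(100):
    train_step()
# params[0] has moved close to the minimizer at 1.0
```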
Raises:

  ValueError: If loss is not callable.
  RuntimeError: If resource variables are not enabled.