Manages all the learning details needed when training an agent.
tf_agents.train.Learner(
    root_dir,
    train_step,
    agent,
    experience_dataset_fn=None,
    after_train_strategy_step_fn=None,
    triggers=None,
    checkpoint_interval=100000,
    summary_interval=1000,
    max_checkpoints_to_keep=3,
    use_kwargs_in_agent_train=False,
    strategy=None,
    run_optimizer_variable_init=True,
    use_reverb_v2=False,
    direct_sampling=False,
    experience_dataset_options=None,
    strategy_run_options=None,
    summary_root_dir=None
)
These include:
- Using distribution strategies correctly
- Summaries
- Checkpoints
- Minimizing entering/exiting TF context: especially in the case of TPUs,
  scheduling a single TPU program to perform multiple train steps is critical
  for performance.
- Generalizing the train call so it is done correctly across CPU, GPU, or TPU
  executions managed by DistributionStrategies. This uses strategy.run and
  then makes sure to do a reduce operation over the LossInfo returned by the
  agent.
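As a minimal sketch of this workflow (my_agent and make_experience_dataset are
illustrative placeholders for an already-built tf_agent.TFAgent and a function
returning a tf.data.Dataset of experience; they are not names from this API):

import tensorflow as tf
from tf_agents.train import learner as learner_lib

# Placeholders assumed to exist: `my_agent` (a tf_agents TFAgent) and
# `make_experience_dataset` (returns a tf.data.Dataset of training experience).
train_step = tf.Variable(0, dtype=tf.int64, trainable=False)

agent_learner = learner_lib.Learner(
    root_dir='/tmp/my_experiment',
    train_step=train_step,
    agent=my_agent,
    experience_dataset_fn=make_experience_dataset)

# Each call schedules the requested number of train steps under the strategy
# and returns the aggregated loss.
loss_info = agent_learner.run(iterations=100)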
Args

root_dir
  Main directory path where checkpoints, saved_models, and summaries (if
  summary_dir is not specified) will be written.

train_step
  A scalar tf.int64 tf.Variable which keeps track of the number of train
  steps. It is used for artifacts created during training, such as summaries
  and other outputs in root_dir.

agent
  tf_agent.TFAgent instance to train with.

experience_dataset_fn
  A function that creates an instance of a tf.data.Dataset used to sample
  experience for training. Required for using the Learner as is. Optional for
  subclass learners which take a new iterator each time learner.run is called.

after_train_strategy_step_fn
  (Optional) Callable of the form fn(sample, loss) which can be used, for
  example, to update priorities in a replay buffer, where sample is pulled
  from the experience_iterator and loss is the LossInfo named tuple returned
  from the agent. This is called after every train step and runs using
  strategy.run(...).

triggers
  List of callables of the form trigger(train_step). After every run call,
  every trigger is called with the current train_step value as a numpy scalar
  (see the sketch after this argument list for an example).

checkpoint_interval
  Number of train steps between checkpoints. Note these are placed into
  triggers, so a check to generate a checkpoint only occurs after every run
  call. Set to -1 to disable (this is not recommended, because if the pipeline
  gets preempted, all previous progress is lost). This only takes care of
  checkpointing the training process; policies must be explicitly exported
  through triggers.

summary_interval
  Number of train steps between summaries. Note these are placed into
  triggers, so a check to generate a summary only occurs after every run call.

max_checkpoints_to_keep
  Maximum number of checkpoints to keep around. These are used to recover from
  pre-emptions when training.

use_kwargs_in_agent_train
  If True, the experience from the replay buffer is passed into the agent as
  kwargs. This requires samples from the replay buffer to be of the form
  dict(experience=experience, kwarg1=kwarg1, ...). This is useful if you have
  an agent with a custom argspec.

strategy
  (Optional) tf.distribute.Strategy to use during training.

run_optimizer_variable_init
  Specifies whether the variables of the optimizer are initialized before
  checkpointing. This should almost always be True (the default) to ensure
  that the state of the optimizer is checkpointed properly. The optimizer
  variables are initialized by building the TensorFlow graph; this is done by
  calling get_concrete_function on the agent's train method, which requires
  passing some input. Since no real data is available at this point, the
  batched form of training_data_spec is used (a standard technique). A problem
  arises when the agent expects some agent-specific batching of the input: in
  that case there is no general way for the learner to batch the affected
  specs properly. To avoid breaking the code in these specific cases, we
  recommend turning off initialization of the optimizer variables by setting
  this field to False.

use_reverb_v2
  If True, the dataset samples are expected to be a named tuple with a data
  and an info field. If False, a tuple(data, info) is expected.

direct_sampling
  Do not use a replay buffer, but sample directly from the offline dataset.

experience_dataset_options
  (Optional) tf.distribute.InputOptions passed to
  strategy.distribute_datasets_from_function, used to control how this dataset
  is distributed.

strategy_run_options
  (Optional) tf.distribute.RunOptions passed to strategy.run. This is passed
  to every strategy.run invocation by the learner.

summary_root_dir
  (Optional) Root directory path where summaries will be written.
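To make experience_dataset_fn and triggers more concrete, here is a hedged
sketch; reverb_replay (a Reverb replay buffer) and agent are assumed to already
exist, and the directory and interval values are arbitrary examples rather than
defaults:

import os
import tensorflow as tf
from tf_agents.train import triggers as trigger_lib

train_step = tf.Variable(0, dtype=tf.int64, trainable=False)

def experience_dataset_fn():
  # Each call should return a fresh tf.data.Dataset of experience; here it is
  # assumed to come from an existing Reverb replay buffer.
  return reverb_replay.as_dataset(sample_batch_size=64, num_steps=2)

triggers = [
    # Exports the agent's policies as SavedModels every 5000 train steps.
    # Checkpointing of the training state itself is governed separately by the
    # Learner's checkpoint_interval argument.
    trigger_lib.PolicySavedModelTrigger(
        os.path.join('/tmp/my_experiment', 'policy'),
        agent,
        train_step,
        interval=5000),
]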
Attributes

train_step_numpy
  The current train_step.
Methods
loss
loss(
    experience_and_sample_info: Optional[tf_agents.train.learner.ExperienceAndSampleInfo] = None,
    reduce_op: tf.distribute.ReduceOp = tf.distribute.ReduceOp.SUM
) -> tf_agents.agents.tf_agent.LossInfo
Computes loss for the experience.
Since this calls agent.loss(), it does not update gradients or increment the
train step counter. Networks are called with training=False, so statistics
like batch norm are not updated.
Args

experience_and_sample_info
  A batch of experience and sample info. If not specified,
  next(self._experience_iterator) is used.

reduce_op
  A tf.distribute.ReduceOp value specifying how loss values should be
  aggregated across replicas.

Returns

The total loss computed.
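For illustration, a call might look like the following (agent_learner is
assumed to be a Learner constructed as in the earlier sketch):

# Compute the loss on the next batch from the learner's own experience
# iterator, without updating weights or the train step counter.
loss_info = agent_learner.loss()

# Average the loss across replicas instead of summing it.
mean_loss_info = agent_learner.loss(reduce_op=tf.distribute.ReduceOp.MEAN)
print(loss_info.loss, mean_loss_info.loss)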
run
run(
    iterations=1, iterator=None, parallel_iterations=10
)
Runs the requested number of training iterations.
Args

iterations
  Number of train iterations to perform per call to run. The iterations will
  be evaluated in a tf.while loop created by autograph. Final aggregated
  losses will be returned.

iterator
  The iterator to the dataset to use for training. If not specified,
  self._experience_iterator is used.

parallel_iterations
  Maximum number of train iterations to allow running in parallel. This value
  is forwarded directly to the training tf.while loop.

Returns

The total loss computed before running the final step.
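A hedged usage sketch, assuming agent_learner and experience_dataset_fn were
created as in the earlier examples:

# Run 200 train steps using the learner's default iterator.
loss_info = agent_learner.run(iterations=200)

# Or supply an explicit iterator over a freshly created dataset.
custom_iterator = iter(experience_dataset_fn())
loss_info = agent_learner.run(iterations=50, iterator=custom_iterator)
print('loss at step', agent_learner.train_step_numpy, ':', loss_info.loss)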
single_train_step
single_train_step(
    iterator
)
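The description for this method is not included in this extract; going by the
name and signature, the sketch below assumes it consumes one batch from the
given iterator and performs a single train step (an assumption, not confirmed
by the text above).

# Hypothetical usage: draw batches from an explicit iterator and apply one
# train step per call (inferred from the method name and signature).
iterator = iter(experience_dataset_fn())
loss_info = agent_learner.single_train_step(iterator)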