Configuration for the "train" part of the train_and_evaluate
call.
tf.estimator.TrainSpec(
    input_fn, max_steps=None, hooks=None, saving_listeners=None
)
TrainSpec determines the input data for training, as well as its
duration. Optional hooks run at various stages of training.
Usage:
>>> train_spec = tf.estimator.TrainSpec(
...     input_fn=lambda: 1,
...     max_steps=100,
...     hooks=[_StopAtSecsHook(stop_after_secs=10)],
...     saving_listeners=[_NewCheckpointListenerForEvaluate(None, 20, None)])
>>> train_spec.saving_listeners[0]._eval_throttle_secs
20
>>> train_spec.hooks[0]._stop_after_secs
10
>>> train_spec.max_steps
100
Args |
input_fn
|
A function that provides input data for training as minibatches.
See Premade Estimators for more information. The function should
construct and return one of the following:
- A tf.data.Dataset object: the outputs of the Dataset object must be a
tuple (features, labels) with the same constraints as below.
- A tuple (features, labels): features is a Tensor or a dictionary
mapping string feature names to Tensors, and labels is a Tensor or a
dictionary mapping string label names to Tensors.
|
max_steps
|
Int. Positive number of total steps for which to train the model.
If None, train forever. The training input_fn is not expected to
generate OutOfRangeError or StopIteration exceptions. See the
train_and_evaluate stop condition section for details.
|
hooks
|
Iterable of tf.train.SessionRunHook objects to run on all workers
(including chief) during training.
|
saving_listeners
|
Iterable of tf.estimator.CheckpointSaverListener
objects to run on chief during training.
|
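The max_steps entry above describes a total step count, with None meaning no limit. The following is a minimal pure-Python sketch of one way to read that, not TensorFlow's implementation. The helper name steps_remaining and its global_step argument are hypothetical, introduced only for illustration, and the sketch assumes max_steps has already been checked to be positive.

```python
from typing import Optional


def steps_remaining(global_step: int, max_steps: Optional[int]) -> Optional[int]:
    """Return how many more steps to run, or None for 'train forever'.

    Hypothetical helper: treats max_steps as a cap on the cumulative
    step count, so steps already taken count toward the total.
    """
    if max_steps is None:
        return None  # no upper bound on training
    # Never return a negative count if the cap was already reached.
    return max(0, max_steps - global_step)


print(steps_remaining(global_step=0, max_steps=100))    # fresh run
print(steps_remaining(global_step=60, max_steps=100))   # resumed at step 60
print(steps_remaining(global_step=100, max_steps=100))  # cap already reached
print(steps_remaining(global_step=0, max_steps=None))   # unbounded
```

Under this reading, a run resumed from a checkpoint at step 60 with max_steps=100 trains for 40 more steps, not 100.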
Raises |
ValueError
|
If any of the input arguments is invalid.
|
TypeError
|
If any of the arguments is not of the expected type.
|
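The Raises table above does not spell out which checks trigger which exception. The sketch below shows one plausible split, assuming that a non-callable input_fn counts as a type error and a non-positive max_steps counts as an invalid value. The function validate_train_spec_args and the helper error_name are hypothetical and are not part of TensorFlow.

```python
def validate_train_spec_args(input_fn, max_steps=None, hooks=None,
                             saving_listeners=None):
    """Hypothetical validator illustrating the Raises table; not TensorFlow code."""
    # Assumption: input_fn must be something that can be called.
    if not callable(input_fn):
        raise TypeError(f"input_fn must be callable, given: {input_fn!r}")
    # Assumption: max_steps is either None or a positive number.
    if max_steps is not None and max_steps <= 0:
        raise ValueError(f"max_steps must be positive, given: {max_steps}")
    # Normalize the optional iterables to tuples.
    return input_fn, max_steps, tuple(hooks or ()), tuple(saving_listeners or ())


def error_name(**kwargs):
    """Return the name of the exception raised for kwargs, or None if valid."""
    try:
        validate_train_spec_args(**kwargs)
    except (TypeError, ValueError) as exc:
        return type(exc).__name__
    return None


print(error_name(input_fn="not a function"))
print(error_name(input_fn=lambda: 1, max_steps=0))
print(error_name(input_fn=lambda: 1, max_steps=100))
```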
Attributes |
input_fn
|
A namedtuple alias for field number 0
|
max_steps
|
A namedtuple alias for field number 1
|
hooks
|
A namedtuple alias for field number 2
|
saving_listeners
|
A namedtuple alias for field number 3
|
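The Attributes table above says each attribute is a namedtuple alias for a numbered field. The sketch below illustrates what that means using only the Python standard library. TrainSpecSketch is a hypothetical stand-in that copies the field order from the signature; it is not the real tf.estimator.TrainSpec and performs no validation.

```python
import collections

# Hypothetical stand-in with the same field order as the documented signature.
TrainSpecSketch = collections.namedtuple(
    "TrainSpecSketch", ["input_fn", "max_steps", "hooks", "saving_listeners"]
)

spec = TrainSpecSketch(
    input_fn=lambda: 1, max_steps=100, hooks=(), saving_listeners=()
)

# Each named attribute reads the same slot as the matching positional index.
assert spec.input_fn is spec[0]
assert spec.max_steps == spec[1]
assert spec.hooks == spec[2]
assert spec.saving_listeners == spec[3]

# Like any tuple, the spec can be unpacked in field order.
input_fn, max_steps, hooks, saving_listeners = spec

print(TrainSpecSketch._fields)
```

Because a namedtuple is immutable, assigning to a field such as spec.max_steps raises AttributeError; a modified copy can be made with spec._replace(max_steps=200).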