Ops and objects returned from a model_fn and passed to an Estimator.
tf.estimator.EstimatorSpec(
    mode,
    predictions=None,
    loss=None,
    train_op=None,
    eval_metric_ops=None,
    export_outputs=None,
    training_chief_hooks=None,
    training_hooks=None,
    scaffold=None,
    evaluation_hooks=None,
    prediction_hooks=None
)
EstimatorSpec fully defines the model to be run by an Estimator.
Args |
mode
|
A ModeKeys value. Specifies whether this is training, evaluation, or
prediction.
|
predictions
|
Predictions Tensor or dict of Tensor.
|
loss
|
Training loss Tensor. Must be either a scalar or have shape [1].
|
train_op
|
Op for the training step.
|
eval_metric_ops
|
Dict of metric results keyed by name.
The values of the dict can be one of the following: (1) an instance of the
Metric class; (2) the result of calling a metric function, namely a
(metric_tensor, update_op) tuple. metric_tensor should be
evaluated without any impact on state (typically it is a pure computation
based on variables). For example, it should not trigger the
update_op or require any input fetching.
|
export_outputs
|
Describes the output signatures to be exported to
SavedModel and used during serving.
A dict {name: output} where:
- name: An arbitrary name for this output.
- output: An ExportOutput object such as ClassificationOutput,
RegressionOutput, or PredictOutput.
Single-headed models only need to specify one entry in this dictionary.
Multi-headed models should specify one entry for each head, one of which
must be named using
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.
If no entry is provided, a default PredictOutput mapping to
predictions will be created.
|
training_chief_hooks
|
Iterable of tf.train.SessionRunHook objects to run
on the chief worker during training.
|
training_hooks
|
Iterable of tf.train.SessionRunHook objects to run on
all workers during training.
|
scaffold
|
A tf.train.Scaffold object that can be used to set
initialization, saver, and more to be used in training.
|
evaluation_hooks
|
Iterable of tf.train.SessionRunHook objects to run
during evaluation.
|
prediction_hooks
|
Iterable of tf.train.SessionRunHook objects to run
during prediction.
|
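The requirement on eval_metric_ops, that metric_tensor can be read without affecting state while update_op is what advances it, can be illustrated with a plain-Python analogy. This is only a sketch: StreamingMean, value, and update are hypothetical names invented for the illustration, not TensorFlow APIs, and the object state stands in for what would be TensorFlow variables.

```python
class StreamingMean:
    """Hypothetical stand-in for a (metric_tensor, update_op) pair."""

    def __init__(self):
        # State that would live in TensorFlow variables.
        self.total = 0.0
        self.count = 0

    def value(self):
        # Plays the role of metric_tensor: a pure read of the current state.
        return self.total / self.count if self.count else 0.0

    def update(self, batch):
        # Plays the role of update_op: consumes one batch and mutates state.
        self.total += sum(batch)
        self.count += len(batch)
        return self.value()


metric = StreamingMean()
metric.update([1.0, 2.0, 3.0])
metric.update([5.0])

# Reading the value twice gives the same answer and consumes no input.
first_read = metric.value()
second_read = metric.value()
```

In this analogy, calling value() any number of times leaves total and count unchanged, which is the property the description asks of metric_tensor.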
Raises |
ValueError
|
If validation fails.
|
TypeError
|
If any of the arguments is not of the expected type.
|
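As one illustration of the kind of check that can fail, the export_outputs rule for multi-headed models can be sketched in plain Python. This is not TensorFlow's validation code: check_export_outputs is a hypothetical helper, the dict values are placeholder strings rather than ExportOutput objects, and DEFAULT_KEY is assumed to mirror the string value of tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.

```python
# Assumed string value of
# tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY.
DEFAULT_KEY = "serving_default"


def check_export_outputs(export_outputs):
    """Hypothetical check mirroring the documented export_outputs rule."""
    if not isinstance(export_outputs, dict):
        raise TypeError(
            "export_outputs must be a dict, got %s" % type(export_outputs).__name__
        )
    # Single-headed models may use any name; multi-headed models need
    # one entry stored under the default serving key.
    if len(export_outputs) > 1 and DEFAULT_KEY not in export_outputs:
        raise ValueError(
            "Multi-headed export_outputs need one entry named %r" % DEFAULT_KEY
        )
    return export_outputs


# Placeholder strings stand in for ExportOutput objects.
single_head = check_export_outputs({"scores": "predict_output"})
multi_head = check_export_outputs(
    {DEFAULT_KEY: "classification_output", "regression": "regression_output"}
)
```

Passing two entries without the default key, for example {"a": "x", "b": "y"}, raises ValueError in this sketch, and passing a list raises TypeError.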
Attributes |
mode
|
A namedtuple alias for field number 0
|
predictions
|
A namedtuple alias for field number 1
|
loss
|
A namedtuple alias for field number 2
|
train_op
|
A namedtuple alias for field number 3
|
eval_metric_ops
|
A namedtuple alias for field number 4
|
export_outputs
|
A namedtuple alias for field number 5
|
training_chief_hooks
|
A namedtuple alias for field number 6
|
training_hooks
|
A namedtuple alias for field number 7
|
scaffold
|
A namedtuple alias for field number 8
|
evaluation_hooks
|
A namedtuple alias for field number 9
|
prediction_hooks
|
A namedtuple alias for field number 10
|
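The phrase "namedtuple alias for field number N" means each attribute is a named view of one tuple position. The sketch below shows this with a stand-in built from collections.namedtuple. SpecSketch is hypothetical and is not the real tf.estimator.EstimatorSpec: it performs no validation, and the "train" mode string and the loss value are placeholders.

```python
from collections import namedtuple

# Hypothetical stand-in with the same field order as the signature above.
# Ten defaults cover every field except mode, which stays required.
SpecSketch = namedtuple(
    "SpecSketch",
    [
        "mode",
        "predictions",
        "loss",
        "train_op",
        "eval_metric_ops",
        "export_outputs",
        "training_chief_hooks",
        "training_hooks",
        "scaffold",
        "evaluation_hooks",
        "prediction_hooks",
    ],
    defaults=(None,) * 10,
)

spec = SpecSketch(mode="train", loss=0.25)

# Attribute access and positional access return the same object.
mode_by_name, mode_by_index = spec.mode, spec[0]
loss_by_name, loss_by_index = spec.loss, spec[2]

# Namedtuples are immutable; _replace returns a modified copy.
updated = spec._replace(loss=0.1)
```

Because a namedtuple is immutable, changing a field means building a new instance with _replace, as the last line does; the original spec keeps its loss of 0.25.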