A helper for use during inference.
Inherits From: Helper
```python
tf.contrib.seq2seq.GreedyEmbeddingHelper(
    embedding, start_tokens, end_token
)
```
Uses the argmax of the output (treated as logits) and passes the result through an embedding layer to get the next input.
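The loop below is a minimal pure-Python sketch of this greedy feedback idea, not the TensorFlow implementation. It decodes a single sequence rather than a batch. `decoder_step` is a hypothetical stand-in for a decoder cell that returns logits, and the embedding is assumed to be a plain list of rows indexed by token id. The sketch includes the end token in the returned ids and stops there.

```python
from typing import Callable, List, Sequence

Vector = List[float]


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (first occurrence on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def greedy_decode(
    decoder_step: Callable[[Vector], Vector],  # hypothetical decoder: inputs -> logits
    embedding: List[Vector],                   # assumed table: row i embeds token i
    start_token: int,
    end_token: int,
    max_steps: int,
) -> List[int]:
    """Greedily decode one sequence, feeding each choice back as the next input."""
    inputs = embedding[start_token]
    ids: List[int] = []
    for _ in range(max_steps):
        logits = decoder_step(inputs)  # outputs are treated as logits
        next_id = argmax(logits)       # greedy choice: the most likely id
        ids.append(next_id)
        if next_id == end_token:
            break
        inputs = embedding[next_id]    # embed the choice to form the next input
    return ids
```

For example, with one-hot embeddings for three tokens and the toy decoder `lambda x: [x[2], x[0], x[1]]` (which always favors the next token id, wrapping around), `greedy_decode(step, emb, 0, 2, 10)` returns `[1, 2]`.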
Args | |
---|---|
`embedding` | A callable that takes a vector tensor of `ids` (argmax ids), or the `params` argument for `embedding_lookup`. The returned tensor will be passed to the decoder input.
`start_tokens` | `int32` vector shaped `[batch_size]`, the start tokens.
`end_token` | `int32` scalar, the token that marks the end of decoding.
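Because `embedding` may be either a callable or an embedding table, the helper has to turn it into a single "ids to embeddings" function. The following is a pure-Python sketch of that dispatch, assuming a table is a list of rows indexed by id (standing in for the `params` of `embedding_lookup`). `make_embedding_fn` is a hypothetical name, not part of the TensorFlow API.

```python
from typing import Callable, List, Sequence, Union

Row = List[float]
EmbeddingFn = Callable[[Sequence[int]], List[Row]]


def make_embedding_fn(embedding: Union[EmbeddingFn, List[Row]]) -> EmbeddingFn:
    """Return a function mapping a vector of ids to a list of embedding rows."""
    if callable(embedding):
        # Already a function of ids: use it unchanged.
        return embedding
    table = embedding
    # Otherwise treat it as a lookup table, analogous to embedding_lookup params.
    return lambda ids: [table[i] for i in ids]
```

With a three-row table, `make_embedding_fn(table)([2, 0])` returns rows 2 and 0 in that order, while a callable is returned as-is.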
Raises | |
---|---|
`ValueError` | If `start_tokens` is not a 1D tensor or `end_token` is not a scalar.
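The rank rule can be illustrated with a hypothetical check on plain Python values, where a 1D tensor is modeled as a flat list of ints and a scalar as a bare int. The real helper inspects tensor shapes, so this sketch only mirrors the condition, not the implementation.

```python
from typing import Any


def validate_tokens(start_tokens: Any, end_token: Any) -> None:
    """Raise ValueError unless start_tokens is 1-D and end_token is a scalar."""
    # A flat list of ints plays the role of a rank-1 int32 tensor.
    is_vector = isinstance(start_tokens, list) and all(
        isinstance(t, int) for t in start_tokens
    )
    if not is_vector:
        raise ValueError("start_tokens must be a vector")
    # A bare int plays the role of a rank-0 (scalar) tensor.
    if not isinstance(end_token, int):
        raise ValueError("end_token must be a scalar")
```

Here `validate_tokens([1, 2], 0)` passes, while a nested list of start tokens or a list-valued `end_token` raises `ValueError`.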
Attributes | |
---|---|
`batch_size` | Batch size of the tensor returned by `sample`. Returns a scalar `int32` tensor.
`sample_ids_dtype` | DType of the tensor returned by `sample`. Returns a `DType`.
`sample_ids_shape` | Shape of the tensor returned by `sample`, excluding the batch dimension. Returns a `TensorShape`.
Methods
initialize
```python
initialize(
    name=None
)
```
Returns `(initial_finished, initial_inputs)`.
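A pure-Python sketch of what `initialize` is expected to produce, assuming that no batch entry starts out finished and that the initial inputs are the embeddings of `start_tokens`. The embedding is modeled as a list of rows indexed by id, and `initialize_sketch` is a hypothetical name.

```python
from typing import List, Sequence, Tuple

Row = List[float]


def initialize_sketch(
    embedding: List[Row], start_tokens: Sequence[int]
) -> Tuple[List[bool], List[Row]]:
    """Return (initial_finished, initial_inputs) for a batch of start tokens."""
    # Nothing has been decoded yet, so no batch entry is finished.
    initial_finished = [False] * len(start_tokens)
    # The first decoder inputs are the embeddings of the start tokens.
    initial_inputs = [embedding[t] for t in start_tokens]
    return initial_finished, initial_inputs
```

For a batch of three start tokens the sketch returns three `False` flags and three embedding rows.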
next_inputs
```python
next_inputs(
    time, outputs, state, sample_ids, name=None
)
```
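`next_inputs_fn` for `GreedyEmbeddingHelper`. Returns `(finished, next_inputs, state)`: an entry is finished when its value in `sample_ids` equals `end_token`, and `next_inputs` is the embedding of `sample_ids`.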
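The step logic can be sketched in pure Python as follows. It assumes `state` passes through unchanged (so it is omitted here) and that, once every entry has emitted `end_token`, the returned inputs are never consumed, so the start-token embeddings serve as a placeholder. `next_inputs_sketch` is a hypothetical name, and the embedding is modeled as a list of rows indexed by id.

```python
from typing import List, Sequence, Tuple

Row = List[float]


def next_inputs_sketch(
    embedding: List[Row],
    start_tokens: Sequence[int],
    end_token: int,
    sample_ids: Sequence[int],
) -> Tuple[List[bool], List[Row]]:
    """Return (finished, next_inputs) for one greedy decoding step."""
    # An entry is finished once it has produced the end token.
    finished = [i == end_token for i in sample_ids]
    if all(finished):
        # Every entry is done, so the next inputs are never used;
        # the start-token embeddings act as a placeholder value.
        return finished, [embedding[t] for t in start_tokens]
    # Otherwise feed the embedding of each greedy choice back in.
    return finished, [embedding[i] for i in sample_ids]
```

Note that a finished entry still receives the embedding of its sampled id while other entries keep decoding; tracking which entries to ignore is left to the caller.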
sample
```python
sample(
    time, outputs, state, name=None
)
```
`sample` for `GreedyEmbeddingHelper`. Treats `outputs` as logits and returns their argmax along the last axis as `int32` sample ids.
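Greedy sampling reduces to a per-row argmax over `[batch_size, vocab_size]` logits. A minimal pure-Python sketch, with logits modeled as a list of rows; `sample_sketch` is a hypothetical name, and it picks the first maximal index on ties, which is a choice of this sketch rather than a documented property of `tf.argmax`.

```python
from typing import List, Sequence


def sample_sketch(logits: Sequence[Sequence[float]]) -> List[int]:
    """Greedy sample ids: argmax over the last axis of batch-major logits."""
    # For each batch row, pick the index of the largest logit.
    return [max(range(len(row)), key=row.__getitem__) for row in logits]
```

For logits `[[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]` the sketch returns `[1, 0]`.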