An inference sampler that takes the argmax of the output distribution.
Inherits From: Sampler
tfa.seq2seq.GreedyEmbeddingSampler(
embedding_fn: Optional[Callable] = None
)
Uses the argmax of the output (treated as logits) and passes the result through an embedding layer to get the next input.
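
The behavior described above can be illustrated with a minimal NumPy sketch. It is not the library's implementation: `greedy_step` is a hypothetical helper name, and the toy `embedding` and `logits` arrays are assumptions chosen only to make the argmax-then-lookup step visible.

```python
import numpy as np

def greedy_step(logits, embedding):
    """One greedy step: take argmax ids, then look up their embeddings."""
    # Treat the decoder output as logits and pick the highest-scoring id.
    sample_ids = np.argmax(logits, axis=-1).astype(np.int32)  # [batch_size]
    # The embedded ids become the decoder's next input.
    next_inputs = embedding[sample_ids]  # [batch_size, embedding_dim]
    return sample_ids, next_inputs

# Toy data (assumed): vocabulary of 4 ids, embedding dimension 3, batch of 2.
embedding = np.arange(12, dtype=np.float32).reshape(4, 3)
logits = np.array([[0.1, 2.0, -1.0, 0.3],
                   [1.5, 0.2, 0.0, 3.0]], dtype=np.float32)

sample_ids, next_inputs = greedy_step(logits, embedding)
print(sample_ids)   # [1 3]
print(next_inputs)  # rows 1 and 3 of the embedding matrix
```

Because the choice is deterministic, repeated calls with the same logits always yield the same ids, which is what distinguishes greedy sampling from stochastic sampling.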
Args | |
---|---|
embedding_fn | An optional callable that takes a vector tensor of ids (argmax ids). The returned tensor is passed to the decoder as its next input. Defaults to tf.nn.embedding_lookup. |
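
To make the role of `embedding_fn` concrete, here is a hedged NumPy sketch of two callables with the expected shape contract (a vector of ids in, a batch of input vectors out). The names `default_embedding_fn` and `scaled_embedding_fn` and the scaling factor are hypothetical; a plain array index stands in for `tf.nn.embedding_lookup`.

```python
import numpy as np

# Toy embedding matrix (assumed): 4 ids, embedding dimension 2.
embedding = np.arange(8, dtype=np.float32).reshape(4, 2)

def default_embedding_fn(ids):
    """Plain lookup, standing in for the default embedding lookup."""
    return embedding[np.asarray(ids)]

def scaled_embedding_fn(ids):
    """Hypothetical custom callable: lookup followed by a constant scale."""
    return embedding[np.asarray(ids)] * np.sqrt(embedding.shape[1])

ids = np.array([2, 0], dtype=np.int32)  # a vector of argmax ids
default_inputs = default_embedding_fn(ids)
scaled_inputs = scaled_embedding_fn(ids)
print(default_inputs.shape, scaled_inputs.shape)  # (2, 2) (2, 2)
```

Any callable with this contract could be supplied, for example one that adds positional information or projects the embeddings before they reach the decoder.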
Methods
initialize
initialize(
embedding, start_tokens=None, end_token=None
)
Initialize the GreedyEmbeddingSampler.
Args | |
---|---|
embedding | Tensor that contains the embedding states matrix. It is used to generate outputs from start_tokens and end_token. It is ignored if embedding_fn was provided at init(). |
start_tokens | int32 vector shaped [batch_size], the start tokens. |
end_token | int32 scalar, the token that marks the end of decoding. |
Returns | |
---|---|
A tuple of two items: (finished, self.start_inputs).
|
Raises | |
---|---|
ValueError | If start_tokens is not a 1D tensor or end_token is not a scalar. |
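
The contract of `initialize` (shape checks, then a `(finished, start_inputs)` tuple) can be sketched in NumPy as follows. This is an illustration under assumptions, not the library code: `initialize_sketch` is a hypothetical name, and the sketch assumes that `finished` starts as all `False` and that `start_inputs` is the embedding of `start_tokens`.

```python
import numpy as np

def initialize_sketch(embedding, start_tokens, end_token):
    """Validate shapes and build the initial (finished, start_inputs) pair."""
    start_tokens = np.asarray(start_tokens, dtype=np.int32)
    end_token = np.asarray(end_token, dtype=np.int32)
    if start_tokens.ndim != 1:
        raise ValueError("start_tokens must be a 1D vector")
    if end_token.ndim != 0:
        raise ValueError("end_token must be a scalar")
    # Assumption: no sequence is finished before decoding begins.
    finished = np.zeros(start_tokens.shape[0], dtype=bool)
    # Assumption: the first decoder input is the embedded start token.
    start_inputs = embedding[start_tokens]
    return finished, start_inputs

embedding = np.eye(4, dtype=np.float32)  # toy one-hot embeddings (assumed)
finished, start_inputs = initialize_sketch(
    embedding, start_tokens=[1, 1], end_token=2)
print(finished)            # [False False]
print(start_inputs.shape)  # (2, 4)
```

Passing a nested list such as `[[1, 1]]` for `start_tokens` raises `ValueError` in this sketch, mirroring the documented error condition.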
next_inputs
next_inputs(
time, outputs, state, sample_ids
)
next_inputs_fn for GreedyEmbeddingSampler.
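
The page does not spell out what `next_inputs` returns, so the following NumPy sketch rests on two labeled assumptions: a sequence is marked finished once its sampled id equals `end_token`, and the result is a `(finished, next_inputs, state)` tuple with the state passed through unchanged. `next_inputs_sketch` is a hypothetical name.

```python
import numpy as np

def next_inputs_sketch(sample_ids, state, embedding_fn, end_token):
    """Compute finished flags and the next decoder inputs from sampled ids."""
    sample_ids = np.asarray(sample_ids, dtype=np.int32)
    # Assumption: emitting end_token marks that batch entry as finished.
    finished = sample_ids == end_token
    # The sampled ids are embedded to form the next decoder input.
    next_inputs = embedding_fn(sample_ids)
    # Assumption: the decoder state is returned untouched.
    return finished, next_inputs, state

embedding = np.eye(3, dtype=np.float32)  # toy embedding matrix (assumed)
finished, next_inputs, state = next_inputs_sketch(
    sample_ids=[0, 2],
    state=None,
    embedding_fn=lambda ids: embedding[ids],
    end_token=2,
)
print(finished)  # [False  True]
```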
sample
sample(
time, outputs, state
)
sample for GreedyEmbeddingSampler.