A training sampler that simply reads its inputs.
Inherits From: Sampler
tfa.seq2seq.TrainingSampler(
time_major: bool = False
)
Returned sample_ids are the argmax of the RNN output logits.
Args | |
---|---|
time_major | Python bool. Whether the tensors in inputs are time major. If False (default), they are assumed to be batch major. |
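Batch-major tensors are indexed [batch, time, ...] and time-major tensors [time, batch, ...], so converting between the two layouts swaps the first two axes. The plain-Python sketch below illustrates this with nested lists standing in for tensors. to_time_major is a hypothetical helper for illustration, not part of TFA.

```python
def to_time_major(batch_major):
    """Swap the first two axes: [batch][time][...] -> [time][batch][...].

    Hypothetical helper; assumes every sequence has the same number of steps.
    """
    return [list(step) for step in zip(*batch_major)]


# Two sequences (batch_size=2), three time steps each, one feature per step.
batch_major = [
    [[1.0], [2.0], [3.0]],  # sequence 0
    [[4.0], [5.0], [6.0]],  # sequence 1
]
time_major = to_time_major(batch_major)

# time_major[t][b] is step t of sequence b.
print(time_major[0])  # [[1.0], [4.0]]
```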
Raises | |
---|---|
ValueError | if sequence_length is not a 1D tensor or mask is not a 2D boolean tensor. |
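The conditions behind this ValueError can be mimicked in plain Python. This is a hypothetical sketch, not TFA's implementation: check_sampler_args and rank are illustrative names, nested lists stand in for tensors, and "boolean" is modeled as Python bool values.

```python
def rank(x):
    """Nesting depth of a rectangular nested list; a scalar has rank 0."""
    depth = 0
    while isinstance(x, (list, tuple)):
        depth += 1
        x = x[0] if x else None
    return depth


def check_sampler_args(sequence_length=None, mask=None):
    """Hypothetical re-creation of the documented argument checks.

    Raises ValueError if sequence_length is not 1-D or if mask is not a
    2-D nested list of Python bools.
    """
    if sequence_length is not None and rank(sequence_length) != 1:
        raise ValueError("sequence_length must be a 1D tensor")
    if mask is not None:
        if rank(mask) != 2 or not all(isinstance(v, bool) for row in mask for v in row):
            raise ValueError("mask must be a 2D boolean tensor")


check_sampler_args(sequence_length=[3, 1])  # OK: 1-D lengths
check_sampler_args(mask=[[True, True, False], [True, False, False]])  # OK: 2-D bools
try:
    check_sampler_args(mask=[[1, 1, 0], [1, 0, 0]])  # ints, not bools
except ValueError as err:
    print(err)  # mask must be a 2D boolean tensor
```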
Methods
initialize
initialize(
inputs, sequence_length=None, mask=None
)
Initializes the TrainingSampler.
Args | |
---|---|
inputs | A (structure of) input tensors. |
sequence_length | An int32 vector tensor. |
mask | A boolean 2D tensor. |
Returns | |
---|---|
(finished, next_inputs) | A tuple of two items. The first item is a boolean vector indicating whether each item in the batch has finished. The second item is the first slice of the input data along the time dimension (the second dimension for batch-major input, the first for time-major input). |
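The following plain-Python sketch models what initialize returns for batch-major input. It is a simplified illustration, not the library code: initialize_sketch is a hypothetical name, nested lists stand in for tensors, and three behaviors are assumptions labeled in the comments (a length is the count of True values in a mask row, missing lengths default to the full time dimension, and a sequence of length 0 counts as finished from the start).

```python
def initialize_sketch(inputs, sequence_length=None, mask=None):
    """Hypothetical model of initialize() for batch-major nested lists.

    inputs has layout [batch][time][depth]. Returns (finished, next_inputs).
    """
    max_time = len(inputs[0])
    if sequence_length is None:
        if mask is not None:
            # Assumption: mask[b][t] is True for valid steps, so a length
            # is the count of True values in that row.
            sequence_length = [sum(row) for row in mask]
        else:
            # Assumption: with neither lengths nor a mask, every sequence
            # spans all max_time steps.
            sequence_length = [max_time] * len(inputs)
    # Assumption: a sequence of length 0 is finished before step 0.
    finished = [length == 0 for length in sequence_length]
    # First slice along the time dimension: step 0 of every sequence.
    next_inputs = [seq[0] for seq in inputs]
    return finished, next_inputs


inputs = [
    [[1.0], [2.0], [3.0]],  # sequence 0
    [[4.0], [5.0], [6.0]],  # sequence 1
]
print(initialize_sketch(inputs, sequence_length=[3, 0]))
# ([False, True], [[1.0], [4.0]])
print(initialize_sketch(inputs, mask=[[True, True, False], [True, False, False]]))
# ([False, False], [[1.0], [4.0]])
```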
next_inputs
next_inputs(
time, outputs, state, sample_ids
)
Returns (finished, next_inputs, next_state).
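Because a training sampler simply reads its inputs, each decoding step can be pictured as teacher forcing: the next input is the ground-truth input at time + 1, whatever the model produced. The sketch below is a hypothetical plain-Python model, not TFA's code. next_inputs_sketch takes the stored inputs and lengths as explicit arguments instead of reading them from the sampler, omits outputs and sample_ids because this sampler does not need them, and assumes that zeros are fed once every sequence has finished and that the state passes through unchanged.

```python
def next_inputs_sketch(time, inputs, sequence_length, state):
    """Hypothetical model of next_inputs() for batch-major nested lists.

    Returns (finished, next_inputs, next_state).
    """
    next_time = time + 1
    # A sequence is finished once the next step reaches its length.
    finished = [next_time >= length for length in sequence_length]
    if all(finished):
        # Assumption: when every sequence is done, feed zeros shaped like one input step.
        next_in = [[0.0] * len(seq[0]) for seq in inputs]
    else:
        # Teacher forcing: read the ground-truth input at the next time step.
        next_in = [seq[next_time] for seq in inputs]
    # Assumption: the state is passed through unchanged.
    return finished, next_in, state


inputs = [
    [[1.0], [2.0], [3.0]],  # sequence 0, length 3
    [[4.0], [5.0], [6.0]],  # sequence 1, length 1
]
lengths = [3, 1]
print(next_inputs_sketch(0, inputs, lengths, state="s"))
# ([False, True], [[2.0], [5.0]], 's')
print(next_inputs_sketch(2, inputs, lengths, state="s"))
# ([True, True], [[0.0], [0.0]], 's')
```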
sample
sample(
time, outputs, state
)
Returns sample_ids.
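Since the returned sample_ids are the argmax of the RNN output logits, sample reduces to picking the index of the largest logit for each batch entry. The plain-Python sketch below illustrates that; sample_sketch is a hypothetical name, outputs is assumed to be a [batch][num_classes] nested list of logits, and time and state are accepted only to mirror the documented signature.

```python
def sample_sketch(time, outputs, state):
    """Hypothetical model of sample(): argmax of each row of logits.

    time and state mirror the documented signature but are unused.
    """
    return [max(range(len(logits)), key=logits.__getitem__) for logits in outputs]


logits = [
    [0.1, 2.0, -1.0],  # largest logit at index 1
    [3.0, 0.5, 0.2],   # largest logit at index 0
]
print(sample_sketch(0, logits, None))  # [1, 0]
```

On ties, Python's max keeps the first maximal index it encounters.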