Interface for implementing sampling in seq2seq decoders.

Sampler classes implement the logic of sampling from the decoder output distribution and producing the inputs for the next decoding step. In most cases they should not be used directly but passed to a `tfa.seq2seq.BasicDecoder` instance, which manages the sampling.
Here is an example using a training sampler directly to implement a custom decoding loop:

```python
import tensorflow as tf
import tensorflow_addons as tfa

batch_size = 4
max_time = 7
hidden_size = 16

sampler = tfa.seq2seq.TrainingSampler()
cell = tf.keras.layers.LSTMCell(hidden_size)

input_tensors = tf.random.uniform([batch_size, max_time, hidden_size])
initial_finished, initial_inputs = sampler.initialize(input_tensors)

cell_input = initial_inputs
cell_state = cell.get_initial_state(initial_inputs)

for time_step in tf.range(max_time):
    cell_output, cell_state = cell(cell_input, cell_state)
    sample_ids = sampler.sample(time_step, cell_output, cell_state)
    finished, cell_input, cell_state = sampler.next_inputs(
        time_step, cell_output, cell_state, sample_ids)
    if tf.reduce_all(finished):
        break
```
Methods
initialize
@abc.abstractmethod
initialize(inputs, **kwargs)

Initializes the sampler with the input tensors.

This method must be invoked exactly once before calling other methods of the Sampler.
Args:
- `inputs`: A (structure of) input tensors; it can be a nested tuple or a single tensor.
- `**kwargs`: Other kwargs for initialization. It can contain tensors, such as a mask for the inputs, or non-tensor parameters.

Returns:
- A `(initial_finished, initial_inputs)` tuple.
next_inputs
@abc.abstractmethod
next_inputs(time, outputs, state, sample_ids)

Returns `(finished, next_inputs, next_state)`.
sample
@abc.abstractmethod
sample(time, outputs, state)

Returns `sample_ids`.
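The three abstract methods above form a simple protocol: `initialize` produces the first inputs, then each step alternates `sample` and `next_inputs` until `finished` is set. As a framework-free illustration of that contract (plain Python only; the `GreedyListSampler` class and its list-based inputs are hypothetical and stand in for the tensor-based `tfa.seq2seq.Sampler`), a teacher-forcing sampler with greedy sampling might look like:

```python
# A minimal, framework-free sketch of the Sampler protocol.
# The real tfa.seq2seq.Sampler operates on TensorFlow tensors; this
# hypothetical class uses plain Python lists to keep the contract visible.

class GreedyListSampler:
    """Feeds ground-truth inputs at each step and samples the argmax id."""

    def __init__(self, inputs):
        # inputs: a list of per-step input vectors.
        self._inputs = inputs

    def initialize(self):
        # Returns (initial_finished, initial_inputs).
        finished = len(self._inputs) == 0
        first = self._inputs[0] if self._inputs else None
        return finished, first

    def sample(self, time, outputs, state):
        # Greedy sampling: the index of the largest output logit.
        return max(range(len(outputs)), key=outputs.__getitem__)

    def next_inputs(self, time, outputs, state, sample_ids):
        # Returns (finished, next_inputs, next_state); finished once the
        # ground-truth inputs are exhausted.
        next_time = time + 1
        finished = next_time >= len(self._inputs)
        nxt = None if finished else self._inputs[next_time]
        return finished, nxt, state


# Toy decoding loop mirroring the example above; the "cell" simply
# echoes its input as the output logits.
sampler = GreedyListSampler([[0.1, 0.9], [0.8, 0.2]])
finished, cell_input = sampler.initialize()
state = None
ids = []
time = 0
while not finished:
    outputs = cell_input  # trivial stand-in for a recurrent cell
    sample_id = sampler.sample(time, outputs, state)
    ids.append(sample_id)
    finished, cell_input, state = sampler.next_inputs(
        time, outputs, state, sample_id)
    time += 1
# ids is now [1, 0]
```

The same shape of loop appears in the training example above; a real subclass would return tensors and use `tf.argmax` and `tf.reduce_all` in place of the list operations here.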