
Steps per execution tuner class.

optimizer: The optimizer used for training/evaluation/prediction. Used to measure iterations and global throughput (optimizer.iterations/second).
spe_variable: A tf.Variable representing the steps_per_execution variable used during training/evaluation/prediction. Must be updatable with spe_variable.assign.
interval: Optional int, the number of seconds to wait between calls to measure throughput and tune spe_variable. Defaults to 5.
change_spe_interval: Optional int, the number of throughput measurements before tuning. Defaults to 10.
change_threshold: Optional float, the percentage difference in throughput that triggers a steps_per_execution change. For example, 0.1 triggers a change if throughput changes by more than 10%.
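The tuning rule itself is internal to Keras. As a rough illustration of how interval, change_spe_interval, and change_threshold interact, here is a minimal pure-Python sketch. The helper names throughput and should_retune are hypothetical, and comparing the mean of one window of measurements against the mean of the previous window is an assumption, not the library's documented behaviour.

```python
from typing import Sequence


def throughput(iterations_before: int, iterations_after: int, seconds: float) -> float:
    """Iterations per second between two readings of an iteration counter.

    Hypothetical helper: in a real loop the readings would come from
    optimizer.iterations, sampled `interval` seconds apart.
    """
    if seconds <= 0:
        raise ValueError("seconds must be positive")
    return (iterations_after - iterations_before) / seconds


def should_retune(
    previous_measurements: Sequence[float],
    current_measurements: Sequence[float],
    change_threshold: float = 0.1,
) -> bool:
    """Return True if mean throughput moved by more than change_threshold.

    Assumption: each window holds change_spe_interval measurements, and the
    relative change is taken against the previous window's mean, so
    change_threshold=0.1 means "more than a 10% change in either direction".
    """
    prev = sum(previous_measurements) / len(previous_measurements)
    curr = sum(current_measurements) / len(current_measurements)
    if prev == 0:
        # No baseline throughput yet (for example, training has not started).
        return False
    return abs(curr - prev) / prev > change_threshold


# Ten measurements per window, mirroring change_spe_interval=10.
baseline = [throughput(0, 500, 5.0)] * 10  # 100 iterations/second
faster = [112.0] * 10
print(should_retune(baseline, faster))  # True: a 12% change exceeds 10%
```

A 5% shift in either direction would leave steps_per_execution alone under this rule, which keeps noisy measurements from causing constant retuning.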


If you're using model.compile and model.fit, this functionality is available at compile time with steps_per_execution='auto':

model.compile(..., steps_per_execution='auto')

Custom training loop usage:

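import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# StepsPerExecutionTuner is assumed to be importable in this scope;
# its module path varies across Keras versions.

# Get model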
inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Prepare the training dataset.
batch_size = 64
(x_train, y_train), (_, _) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
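train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# Repeat so the iterator does not run out of batches as the tuner raises steps_per_execution.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size).repeat()
iterator = iter(train_dataset)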

# Create our steps per execution variable
steps_per_execution = tf.Variable(1, dtype="int64")
# Create the tuner
tuner = StepsPerExecutionTuner(
    optimizer, steps_per_execution
)

# Create a step function that runs a single training step
@tf.function
def step_fn(iterator):
    batch_data, labels = next(iterator)
    with tf.GradientTape() as tape:
        logits = model(batch_data, training=True)
        loss_value = loss_fn(labels, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))

# We can now pack multiple execution steps into one call
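@tf.function
def multi_step_train_fn(iterator, steps_per_execution):
    # AutoGraph converts this tf.range loop into a graph loop, so each call
    # reads the variable's current value without retracing the function.
    for _ in tf.range(steps_per_execution):
        step_fn(iterator)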

initial_steps_per_execution = 1
steps_per_epoch = 100
epochs = 2

# Start the tuner before training
tuner.start()

# We can now call our multi step training with our data
for epoch in range(epochs):
    for _ in range(steps_per_epoch):
        multi_step_train_fn(iterator, steps_per_execution)

# End the tuner after training
tuner.stop()

steps_per_execution: Settable attribute representing the steps_per_execution variable.



Starts steps per execution tuning thread.

Returns a threading.Thread which will run every self.interval seconds to measure throughput and tune steps per execution.


Stops steps per execution tuning thread.
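start() is documented as returning a threading.Thread that wakes every interval seconds, and stop() ends it. That lifecycle can be sketched with the standard library alone. The class PeriodicThroughputMonitor and its read_iterations callback are hypothetical names, and signalling shutdown through a threading.Event is an assumption about one reasonable design, not a description of the Keras implementation. In a real loop, read_iterations might be lambda: int(optimizer.iterations.numpy()).

```python
import threading
import time
from typing import Callable, List, Optional


class PeriodicThroughputMonitor:
    """Hypothetical stand-in for a tuner's background measuring thread.

    Every `interval` seconds it reads an iteration counter and records
    iterations/second, until stop() is called.
    """

    def __init__(self, read_iterations: Callable[[], int], interval: float = 5.0):
        self.read_iterations = read_iterations
        self.interval = interval
        self.measurements: List[float] = []
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def _run(self) -> None:
        last_iters = self.read_iterations()
        last_time = time.monotonic()
        # Event.wait returns False on timeout and True once stop() sets it.
        while not self._stop_event.wait(self.interval):
            now_iters = self.read_iterations()
            now_time = time.monotonic()
            elapsed = now_time - last_time
            if elapsed > 0:
                self.measurements.append((now_iters - last_iters) / elapsed)
            last_iters, last_time = now_iters, now_time

    def start(self) -> threading.Thread:
        # Reuse a running thread instead of starting a second one.
        if self._thread is not None and self._thread.is_alive():
            return self._thread
        self._stop_event.clear()
        # daemon=True so a forgotten stop() does not block interpreter exit.
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()
        return self._thread

    def stop(self) -> None:
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join()


# Usage: a plain dict counter stands in for optimizer.iterations.
counter = {"iters": 0}
monitor = PeriodicThroughputMonitor(lambda: counter["iters"], interval=0.05)
thread = monitor.start()
for _ in range(20):
    counter["iters"] += 10
    time.sleep(0.01)
monitor.stop()
print(thread.is_alive())  # False
```

Waiting on an Event rather than calling time.sleep(interval) lets stop() take effect immediately instead of after up to one full interval.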