TensorFlow Addons Callbacks: TimeStopping

Overview

This notebook demonstrates how to use the TimeStopping callback from TensorFlow Addons, which stops training once a fixed wall-clock time budget (the seconds argument) has elapsed.

Setup

pip install -q -U tensorflow-addons
import tensorflow_addons as tfa

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten

Import and normalize data

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# normalize data
x_train, x_test = x_train / 255.0, x_test / 255.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 0s 0us/step
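
Dividing by 255.0 rescales the 8-bit pixel intensities (0 to 255) into the [0, 1] range. Below is a minimal NumPy illustration of that step. The few pixel values are hypothetical and are not taken from MNIST.

```python
import numpy as np

# MNIST images load as uint8 arrays with pixel intensities in [0, 255].
# A tiny hypothetical row of pixels stands in for a real image here.
pixels = np.array([0, 64, 128, 255], dtype=np.uint8)

# True division by 255.0 promotes the data to a floating-point dtype and
# maps 0 -> 0.0 and 255 -> 1.0, the same operation applied to x_train and x_test.
scaled = pixels / 255.0

print(scaled.dtype, scaled.min(), scaled.max())
```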

Build a simple MNIST model

# build the model using the Sequential API
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
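
The model above is fully connected (a Flatten layer followed by Dense layers), not convolutional. As a sanity check on its size, you can count the trainable parameters by hand. The figures below assume the standard Dense layout: one weight per input/unit pair plus one bias per unit. They should match what model.summary() reports. The helper dense_params is hypothetical.

```python
# Hand count of the trainable parameters in the model built above.

def dense_params(n_in: int, n_out: int) -> int:
    """Weights (n_in * n_out) plus one bias per unit (n_out) for a Dense layer."""
    return n_in * n_out + n_out

flatten_out = 28 * 28                     # Flatten: 28x28 image -> 784 features, no parameters
hidden = dense_params(flatten_out, 128)   # Dense(128): 784*128 + 128
# Dropout(0.2) has no trainable parameters
output = dense_params(128, 10)            # Dense(10): 128*10 + 10
total = hidden + output

print(flatten_out, hidden, output, total)  # 784 100480 1290 101770
```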

Simple TimeStopping usage

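# initialize the TimeStopping callback: stop training once 5 seconds have
# elapsed; verbose=1 prints a message when training is stopped
time_stopping_callback = tfa.callbacks.TimeStopping(seconds=5, verbose=1)

# train the model with time_stopping_callback; epochs=100 is far more than
# fits in the 5-second budget, so the callback ends training early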
model.fit(x_train, y_train,
          batch_size=64,
          epochs=100,
          callbacks=[time_stopping_callback],
          validation_data=(x_test, y_test))
Epoch 1/100
938/938 [==============================] - 3s 3ms/step - loss: 0.5649 - accuracy: 0.8378 - val_loss: 0.1624 - val_accuracy: 0.9548
Epoch 2/100
938/938 [==============================] - 2s 2ms/step - loss: 0.1684 - accuracy: 0.9514 - val_loss: 0.1160 - val_accuracy: 0.9653
Timed stopping at epoch 2 after training for 0:00:05
<tensorflow.python.keras.callbacks.History at 0x7f3b947672b0>
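
In the log above, training stopped after two full epochs (roughly 3 s + 2 s), even though epochs=100 was requested. That pattern is consistent with the deadline being checked only when an epoch finishes. As a result, a run can overshoot its time budget by up to one epoch. The sketch below simulates this behaviour without TensorFlow. SimulatedTimeStopping, FakeClock, and simulate are hypothetical names; they only mirror the Keras callback hook names (on_train_begin, on_epoch_end). The epoch durations are made-up values, not measurements.

```python
class FakeClock:
    """Hypothetical manually advanced clock, so the simulation is deterministic."""

    def __init__(self):
        self.now = 0.0

    def __call__(self):
        return self.now

    def advance(self, dt):
        self.now += dt


class SimulatedTimeStopping:
    """Hypothetical stand-in for a time-budget callback, checked at epoch end."""

    def __init__(self, seconds, clock):
        self.seconds = seconds
        self.clock = clock            # zero-argument callable returning "now" in seconds
        self.stop_training = False    # plays the role of model.stop_training
        self.stopped_epoch = None     # 0-based index of the epoch after which we stopped

    def on_train_begin(self):
        self.stopping_time = self.clock() + self.seconds

    def on_epoch_end(self, epoch):
        # The deadline is only checked here, never in the middle of an epoch.
        if self.clock() >= self.stopping_time:
            self.stop_training = True
            self.stopped_epoch = epoch


def simulate(epoch_durations, seconds):
    """Run a fake training loop; return (epochs completed, simulated elapsed seconds)."""
    clock = FakeClock()
    cb = SimulatedTimeStopping(seconds, clock)
    cb.on_train_begin()
    epochs_run = 0
    for epoch, duration in enumerate(epoch_durations):
        clock.advance(duration)       # "train" for one full epoch
        epochs_run += 1
        cb.on_epoch_end(epoch)
        if cb.stop_training:
            break
    return epochs_run, clock.now


print(simulate([3.0, 2.0, 2.0, 2.0], seconds=5))  # (2, 5.0)
print(simulate([3.0, 3.0, 3.0], seconds=5))       # (2, 6.0)
```

With 3-second epochs and the same 5-second budget, the second check happens at 6 s. This shows the one-epoch overshoot that an end-of-epoch check allows.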