Richiamate dei componenti aggiuntivi di TensorFlow: barra di avanzamento TQDM

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino

Panoramica

Questo notebook dimostrerà come utilizzare TQDMCallback in TensorFlow Addons.

Impostare

pip install -U tensorflow-addons
!pip install -q "tqdm>=4.36.1"

import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
import tqdm

# quietly deep-reload tqdm
import sys
from IPython.lib import deepreload 

stdout = sys.stdout
sys.stdout = open('junk','w')
deepreload.reload(tqdm)
sys.stdout = stdout

tqdm.__version__
'4.62.3'

Importa e normalizza i dati

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# normalize data
x_train, x_test = x_train / 255.0, x_test / 255.0

Costruisci un modello CNN MNIST semplice

# build the model using the Sequential API
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10, activation='softmax'))

model.compile(optimizer='adam',
              loss = 'sparse_categorical_crossentropy',
              metrics=['accuracy'])

Utilizzo predefinito di TQDMCallback

# initialize tqdm callback with default parameters
tqdm_callback = tfa.callbacks.TQDMProgressBar()

# train the model with tqdm_callback
# make sure to set verbose = 0 to disable
# the default progress bar.
model.fit(x_train, y_train,
          batch_size=64,
          epochs=10,
          verbose=0,
          callbacks=[tqdm_callback],
          validation_data=(x_test, y_test))
Training:   0%|           0/10 ETA: ?s,  ?epochs/s
Epoch 1/10
0/938           ETA: ?s -
Epoch 2/10
0/938           ETA: ?s -
Epoch 3/10
0/938           ETA: ?s -
Epoch 4/10
0/938           ETA: ?s -
Epoch 5/10
0/938           ETA: ?s -
Epoch 6/10
0/938           ETA: ?s -
Epoch 7/10
0/938           ETA: ?s -
Epoch 8/10
0/938           ETA: ?s -
Epoch 9/10
0/938           ETA: ?s -
Epoch 10/10
0/938           ETA: ?s -
<keras.callbacks.History at 0x7f4a8d35aed0>

Di seguito è riportato l'output previsto quando si esegue la cella sopra Figura della barra di avanzamento TQDM

# TQDMProgressBar() also works with evaluate()
model.evaluate(x_test, y_test, batch_size=64, callbacks=[tqdm_callback], verbose=0)
0/157           ETA: ?s - Evaluating
[0.06689586490392685, 0.9805999994277954]

Di seguito è riportato l'output previsto quando si esegue la cella sopra TQDM Valuta la figura della barra di avanzamento