Voir sur TensorFlow.org | Exécuter dans Google Colab | Voir la source sur GitHub | Télécharger le cahier |
Aperçu
Ce bloc-notes montre comment utiliser l'optimiseur de moyenne mobile avec le point de contrôle de la moyenne du modèle du package d'extensions tensorflow.
Moyenne mobile
L'avantage de la moyenne mobile est qu'elle est moins sujette aux décalages de perte endémiques ou à la représentation irrégulière des données dans le dernier lot. Cela donne une idée lissée et plus générale de la formation du modèle jusqu'à un certain point.
Moyenne stochastique
La moyenne des poids stochastiques converge vers des optima plus larges. Ce faisant, il ressemble à un assemblage géométrique. SWA est une méthode simple pour améliorer les performances du modèle lorsqu'il est utilisé comme enveloppe autour d'autres optimiseurs et en faisant la moyenne des résultats à partir de différents points de trajectoire de l'optimiseur interne.
Point de contrôle moyen du modèle
callbacks.ModelCheckpoint
ne vous donne pas la possibilité d'enregistrer le déplacement des poids moyens au milieu de la formation, ce qui explique pourquoi Modèle moyen optimiseurs nécessaire un rappel personnalisé. En utilisant leupdate_weights
paramètre,ModelAverageCheckpoint
vous permet de:
- Attribuez les poids moyens mobiles au modèle et enregistrez-les.
- Conservez les anciens poids non moyens, mais le modèle enregistré utilise les poids moyens.
Installer
pip install -U tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import os
Construire le modèle
def create_model(opt):
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer=opt,
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
Préparer l'ensemble de données
#Load Fashion MNIST dataset
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255.0
labels = labels.astype(np.int32)
fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)
test_images, test_labels = test
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz 32768/29515 [=================================] - 0s 0us/step 40960/29515 [=========================================] - 0s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz 26427392/26421880 [==============================] - 0s 0us/step 26435584/26421880 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz 16384/5148 [===============================================================================================] - 0s 0us/step Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz 4423680/4422102 [==============================] - 0s 0us/step 4431872/4422102 [==============================] - 0s 0us/step
Nous allons comparer trois optimiseurs ici :
- SGD non emballé
- SGD avec moyenne mobile
- SGD avec pondération stochastique
Et voyez comment ils fonctionnent avec le même modèle.
#Optimizers
sgd = tf.keras.optimizers.SGD(0.01)
moving_avg_sgd = tfa.optimizers.MovingAverage(sgd)
stocastic_avg_sgd = tfa.optimizers.SWA(sgd)
Les deux MovingAverage
et StocasticAverage
optimers utilisent ModelAverageCheckpoint
.
#Callback
checkpoint_path = "./training/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_dir,
save_weights_only=True,
verbose=1)
avg_callback = tfa.callbacks.AverageModelCheckpoint(filepath=checkpoint_dir,
update_weights=True)
Modèle de train
Optimiseur de vanille SGD
#Build Model
model = create_model(sgd)
#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[cp_callback])
Epoch 1/5 1875/1875 [==============================] - 4s 2ms/step - loss: 0.8031 - accuracy: 0.7282 Epoch 00001: saving model to ./training Epoch 2/5 1875/1875 [==============================] - 3s 2ms/step - loss: 0.5049 - accuracy: 0.8240 Epoch 00002: saving model to ./training Epoch 3/5 1875/1875 [==============================] - 3s 2ms/step - loss: 0.4591 - accuracy: 0.8375 Epoch 00003: saving model to ./training Epoch 4/5 1875/1875 [==============================] - 3s 2ms/step - loss: 0.4328 - accuracy: 0.8492 Epoch 00004: saving model to ./training Epoch 5/5 1875/1875 [==============================] - 3s 2ms/step - loss: 0.4128 - accuracy: 0.8561 Epoch 00005: saving model to ./training <keras.callbacks.History at 0x7fc9d0262250>
#Evalute results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 95.4645 - accuracy: 0.7796 Loss : 95.46446990966797 Accuracy : 0.7796000242233276
Moyenne mobile SGD
#Build Model
model = create_model(moving_avg_sgd)
#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5 1875/1875 [==============================] - 4s 2ms/step - loss: 0.8064 - accuracy: 0.7303 2021-09-02 00:35:29.787996: W tensorflow/python/util/util.cc:348] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: ./training/assets Epoch 2/5 1875/1875 [==============================] - 4s 2ms/step - loss: 0.5114 - accuracy: 0.8223 INFO:tensorflow:Assets written to: ./training/assets Epoch 3/5 1875/1875 [==============================] - 4s 2ms/step - loss: 0.4620 - accuracy: 0.8382 INFO:tensorflow:Assets written to: ./training/assets Epoch 4/5 1875/1875 [==============================] - 4s 2ms/step - loss: 0.4345 - accuracy: 0.8470 INFO:tensorflow:Assets written to: ./training/assets Epoch 5/5 1875/1875 [==============================] - 4s 2ms/step - loss: 0.4146 - accuracy: 0.8547 INFO:tensorflow:Assets written to: ./training/assets <keras.callbacks.History at 0x7fc8e16f30d0>
#Evalute results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 95.4645 - accuracy: 0.7796 Loss : 95.46446990966797 Accuracy : 0.7796000242233276
Moyenne du poids stocastique SGD
#Build Model
model = create_model(stocastic_avg_sgd)
#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5 1875/1875 [==============================] - 5s 2ms/step - loss: 0.7896 - accuracy: 0.7350 INFO:tensorflow:Assets written to: ./training/assets Epoch 2/5 1875/1875 [==============================] - 5s 2ms/step - loss: 0.5670 - accuracy: 0.8065 INFO:tensorflow:Assets written to: ./training/assets Epoch 3/5 1875/1875 [==============================] - 5s 2ms/step - loss: 0.5345 - accuracy: 0.8142 INFO:tensorflow:Assets written to: ./training/assets Epoch 4/5 1875/1875 [==============================] - 5s 2ms/step - loss: 0.5194 - accuracy: 0.8188 INFO:tensorflow:Assets written to: ./training/assets Epoch 5/5 1875/1875 [==============================] - 5s 2ms/step - loss: 0.5089 - accuracy: 0.8235 INFO:tensorflow:Assets written to: ./training/assets <keras.callbacks.History at 0x7fc8e0538790>
#Evalute results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 95.4645 - accuracy: 0.7796 Loss : 95.46446990966797 Accuracy : 0.7796000242233276