TensorFlow.orgで 表示 | Run in Google Colab | GitHub でソースを表示 | ノートブックをダウンロード |
概要
このノートブックでは、TensorFlow Addons パッケージから移動平均オプティマイザとモデル平均チェックポイントを使用する方法を紹介します。
移動平均化
移動平均化の利点は、最新のバッチで激しい損失の変化や不規則なデータ表現を発生させにくいことです。ある時点までのモデルのトレーニングがスムーズになり、より一般的なアイデアを提供します。
確率的平均化
確率的重み平均化は、より広いオプティマイザに収束します。これは幾何学的なアンサンブルに似ています。確率的重み平均化は、他のオプティマイザのラッパーとして使用し、内側のオプティマイザのトラジェクトリの異なる点からの結果を平均化することでモデルの性能を向上させる、シンプルな方法です。
モデル平均チェックポイント
callbacks.ModelCheckpoint
にはトレーニングの途中で移動平均の重みを保存するオプションがないため、モデル平均オプティマイザにはカスタムコールバックが必要でした。update_weights
パラメータを使用すると、ModelAverageCheckpoint
で以下が可能になります。
- モデルに移動平均重みを割り当てて保存する。
- 古い平均化されていない重みはそのままにして、保存されたモデルは平均化された重みを使用する。
セットアップ
pip install -q -U tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import os
モデルを構築する
def create_model(opt):
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer=opt,
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
データセットを準備する
#Load Fashion MNIST dataset
train, test = tf.keras.datasets.fashion_mnist.load_data()
images, labels = train
images = images/255.0
labels = labels.astype(np.int32)
fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)
test_images, test_labels = test
ここでは、次の 3 つのオプティマイザを比較してみます。
- ラップされていない SGD
- 移動平均を適用した SGD
- 確率的重み平均を適用した SGD
同じモデルを使用してパフォーマンスを見てみましょう。
#Optimizers
sgd = tf.keras.optimizers.SGD(0.01)
moving_avg_sgd = tfa.optimizers.MovingAverage(sgd)
stocastic_avg_sgd = tfa.optimizers.SWA(sgd)
MovingAverage
オプティマイザとStocasticAverage
オプティマイザは、どちらもModelAverageCheckpoint
を使用します。
#Callback
checkpoint_path = "./training/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_dir,
save_weights_only=True,
verbose=1)
avg_callback = tfa.callbacks.AverageModelCheckpoint(filepath=checkpoint_dir,
update_weights=True)
モデルをトレーニングする
Vanilla SGD オプティマイザ
#Build Model
model = create_model(sgd)
#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[cp_callback])
Epoch 1/5 1875/1875 [==============================] - 5s 2ms/step - loss: 1.1251 - accuracy: 0.6304 Epoch 00001: saving model to ./training Epoch 2/5 1875/1875 [==============================] - 4s 2ms/step - loss: 0.5238 - accuracy: 0.8171 Epoch 00002: saving model to ./training Epoch 3/5 1875/1875 [==============================] - 4s 2ms/step - loss: 0.4629 - accuracy: 0.8380 Epoch 00003: saving model to ./training Epoch 4/5 1875/1875 [==============================] - 4s 2ms/step - loss: 0.4359 - accuracy: 0.8467 Epoch 00004: saving model to ./training Epoch 5/5 1875/1875 [==============================] - 4s 2ms/step - loss: 0.4145 - accuracy: 0.8545 Epoch 00005: saving model to ./training <tensorflow.python.keras.callbacks.History at 0x7f64ac74e898>
#Evalute results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 1s - loss: 83.8808 - accuracy: 0.7944 Loss : 83.88079833984375 Accuracy : 0.7943999767303467
移動平均 SGD
#Build Model
model = create_model(moving_avg_sgd)
#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5 1875/1875 [==============================] - 5s 2ms/step - loss: 1.0931 - accuracy: 0.6498 INFO:tensorflow:Assets written to: ./training/assets Epoch 2/5 1875/1875 [==============================] - 5s 2ms/step - loss: 0.5204 - accuracy: 0.8193 INFO:tensorflow:Assets written to: ./training/assets Epoch 3/5 1875/1875 [==============================] - 5s 2ms/step - loss: 0.4701 - accuracy: 0.8352 INFO:tensorflow:Assets written to: ./training/assets Epoch 4/5 1875/1875 [==============================] - 5s 2ms/step - loss: 0.4371 - accuracy: 0.8474 INFO:tensorflow:Assets written to: ./training/assets Epoch 5/5 1875/1875 [==============================] - 5s 2ms/step - loss: 0.4164 - accuracy: 0.8548 INFO:tensorflow:Assets written to: ./training/assets <tensorflow.python.keras.callbacks.History at 0x7f64ac544da0>
#Evalute results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 83.8808 - accuracy: 0.7944 Loss : 83.88079833984375 Accuracy : 0.7943999767303467
確率的重み平均 SGD
#Build Model
model = create_model(stocastic_avg_sgd)
#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5 1875/1875 [==============================] - 6s 3ms/step - loss: 1.0524 - accuracy: 0.6586 INFO:tensorflow:Assets written to: ./training/assets Epoch 2/5 1875/1875 [==============================] - 5s 3ms/step - loss: 0.5922 - accuracy: 0.7989 INFO:tensorflow:Assets written to: ./training/assets Epoch 3/5 1875/1875 [==============================] - 5s 3ms/step - loss: 0.5485 - accuracy: 0.8112 INFO:tensorflow:Assets written to: ./training/assets Epoch 4/5 1875/1875 [==============================] - 5s 3ms/step - loss: 0.5288 - accuracy: 0.8184 INFO:tensorflow:Assets written to: ./training/assets Epoch 5/5 1875/1875 [==============================] - 5s 3ms/step - loss: 0.5147 - accuracy: 0.8205 INFO:tensorflow:Assets written to: ./training/assets <tensorflow.python.keras.callbacks.History at 0x7f64a0186c18>
#Evalute results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 83.8808 - accuracy: 0.7944 Loss : 83.88079833984375 Accuracy : 0.7943999767303467