モデルの平均化

TensorFlow.orgで 表示 Run in Google Colab GitHub でソースを表示 ノートブックをダウンロード

概要

このノートブックでは、TensorFlow Addons パッケージから移動平均オプティマイザとモデル平均チェックポイントを使用する方法を紹介します。

移動平均化

移動平均化の利点は、最新のバッチで激しい損失の変化や不規則なデータ表現を発生させにくいことです。ある時点までのモデルのトレーニングがスムーズになり、より一般的なアイデアを提供します。

確率的平均化

確率的重み平均化は、より広いオプティマイザに収束します。これは幾何学的なアンサンブルに似ています。確率的重み平均化は、他のオプティマイザのラッパーとして使用し、内側のオプティマイザのトラジェクトリの異なる点からの結果を平均化することでモデルの性能を向上させる、シンプルな方法です。

モデル平均チェックポイント

callbacks.ModelCheckpointにはトレーニングの途中で移動平均の重みを保存するオプションがないため、モデル平均オプティマイザにはカスタムコールバックが必要でした。update_weightsパラメータを使用すると、ModelAverageCheckpointで以下が可能になります。

  1. モデルに移動平均重みを割り当てて保存する。
  2. 古い平均化されていない重みはそのままにして、保存されたモデルは平均化された重みを使用する。

セットアップ

pip install -q -U tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
import os

モデルを構築する

def create_model(opt):
    model = tf.keras.models.Sequential([
        tf.keras.layers.Flatten(),                         
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer=opt,
                    loss='sparse_categorical_crossentropy',
                    metrics=['accuracy'])

    return model

データセットを準備する

#Load Fashion MNIST dataset
train, test = tf.keras.datasets.fashion_mnist.load_data()

images, labels = train
images = images/255.0
labels = labels.astype(np.int32)

fmnist_train_ds = tf.data.Dataset.from_tensor_slices((images, labels))
fmnist_train_ds = fmnist_train_ds.shuffle(5000).batch(32)

test_images, test_labels = test

ここでは、次の 3 つのオプティマイザを比較してみます。

  • ラップされていない SGD
  • 移動平均を適用した SGD
  • 確率的重み平均を適用した SGD

同じモデルを使用してパフォーマンスを見てみましょう。

#Optimizers 
sgd = tf.keras.optimizers.SGD(0.01)
moving_avg_sgd = tfa.optimizers.MovingAverage(sgd)
stocastic_avg_sgd = tfa.optimizers.SWA(sgd)

MovingAverageオプティマイザとStocasticAverageオプティマイザは、どちらもModelAverageCheckpointを使用します。

#Callback 
checkpoint_path = "./training/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_dir,
                                                 save_weights_only=True,
                                                 verbose=1)
avg_callback = tfa.callbacks.AverageModelCheckpoint(filepath=checkpoint_dir, 
                                                    update_weights=True)

モデルをトレーニングする

Vanilla SGD オプティマイザ

#Build Model
model = create_model(sgd)

#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[cp_callback])
Epoch 1/5
1875/1875 [==============================] - 5s 2ms/step - loss: 1.1251 - accuracy: 0.6304

Epoch 00001: saving model to ./training
Epoch 2/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.5238 - accuracy: 0.8171

Epoch 00002: saving model to ./training
Epoch 3/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4629 - accuracy: 0.8380

Epoch 00003: saving model to ./training
Epoch 4/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4359 - accuracy: 0.8467

Epoch 00004: saving model to ./training
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.4145 - accuracy: 0.8545

Epoch 00005: saving model to ./training
<tensorflow.python.keras.callbacks.History at 0x7f64ac74e898>
#Evalute results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 1s - loss: 83.8808 - accuracy: 0.7944
Loss : 83.88079833984375
Accuracy : 0.7943999767303467

移動平均 SGD

#Build Model
model = create_model(moving_avg_sgd)

#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5
1875/1875 [==============================] - 5s 2ms/step - loss: 1.0931 - accuracy: 0.6498
INFO:tensorflow:Assets written to: ./training/assets
Epoch 2/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.5204 - accuracy: 0.8193
INFO:tensorflow:Assets written to: ./training/assets
Epoch 3/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.4701 - accuracy: 0.8352
INFO:tensorflow:Assets written to: ./training/assets
Epoch 4/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.4371 - accuracy: 0.8474
INFO:tensorflow:Assets written to: ./training/assets
Epoch 5/5
1875/1875 [==============================] - 5s 2ms/step - loss: 0.4164 - accuracy: 0.8548
INFO:tensorflow:Assets written to: ./training/assets
<tensorflow.python.keras.callbacks.History at 0x7f64ac544da0>
#Evalute results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 83.8808 - accuracy: 0.7944
Loss : 83.88079833984375
Accuracy : 0.7943999767303467

確率的重み平均 SGD

#Build Model
model = create_model(stocastic_avg_sgd)

#Train the network
model.fit(fmnist_train_ds, epochs=5, callbacks=[avg_callback])
Epoch 1/5
1875/1875 [==============================] - 6s 3ms/step - loss: 1.0524 - accuracy: 0.6586
INFO:tensorflow:Assets written to: ./training/assets
Epoch 2/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5922 - accuracy: 0.7989
INFO:tensorflow:Assets written to: ./training/assets
Epoch 3/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5485 - accuracy: 0.8112
INFO:tensorflow:Assets written to: ./training/assets
Epoch 4/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5288 - accuracy: 0.8184
INFO:tensorflow:Assets written to: ./training/assets
Epoch 5/5
1875/1875 [==============================] - 5s 3ms/step - loss: 0.5147 - accuracy: 0.8205
INFO:tensorflow:Assets written to: ./training/assets
<tensorflow.python.keras.callbacks.History at 0x7f64a0186c18>
#Evalute results
model.load_weights(checkpoint_dir)
loss, accuracy = model.evaluate(test_images, test_labels, batch_size=32, verbose=2)
print("Loss :", loss)
print("Accuracy :", accuracy)
313/313 - 0s - loss: 83.8808 - accuracy: 0.7944
Loss : 83.88079833984375
Accuracy : 0.7943999767303467