TensorFlow Addons Layers: WeightNormalization


Overview

This notebook demonstrates how to use the Weight Normalization layer and how it can improve convergence.

WeightNormalization

A simple reparameterization to accelerate training of deep neural networks:

Tim Salimans, Diederik P. Kingma (2016)

By reparameterizing the weights in this way, we improve the conditioning of the optimization problem and speed up convergence of stochastic gradient descent. The reparameterization is inspired by batch normalization, but does not introduce any dependencies between the examples in a minibatch. This means the method can also be applied successfully to recurrent models such as LSTMs, and to noise-sensitive applications such as deep reinforcement learning or generative models, for which batch normalization is less well suited. Although the method is much simpler, it still provides much of the speed-up of full batch normalization. In addition, its computational overhead is lower, permitting more optimization steps to be taken in the same amount of time.

https://arxiv.org/abs/1602.07868
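
Concretely, the method decouples each weight vector w into a direction v and a scalar magnitude g, reconstructing w = g * v / ||v||, so the norm of w is controlled by g alone. A minimal NumPy sketch of this reparameterization (for illustration only; this is not TFA's internal implementation):

# Weight normalization reparameterization: w = g * v / ||v||
import numpy as np

v = np.random.randn(5)           # direction parameter
g = 2.0                          # scalar magnitude parameter
w = g * v / np.linalg.norm(v)    # reconstructed weight vector
print(np.linalg.norm(w))         # ~2.0: the norm of w is set by g alone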

Setup

pip install -q -U tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
from matplotlib import pyplot as plt
# Hyperparameters
batch_size = 32
epochs = 10
num_classes = 10

Build Models

# Standard ConvNet
reg_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(6, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(16, 5, activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='relu'),
    tf.keras.layers.Dense(84, activation='relu'),
    tf.keras.layers.Dense(num_classes, activation='softmax'),
])
# WeightNorm ConvNet
wn_model = tf.keras.Sequential([
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(6, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tfa.layers.WeightNormalization(tf.keras.layers.Conv2D(16, 5, activation='relu')),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(120, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(84, activation='relu')),
    tfa.layers.WeightNormalization(tf.keras.layers.Dense(num_classes, activation='softmax')),
])
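
tfa.layers.WeightNormalization is a layer wrapper: it takes an existing layer and reparameterizes its kernel as described above. A quick sketch of wrapping a single Dense layer in isolation (data_init=True, the default, runs the paper's data-dependent initialization on the first call):

# Wrap one Dense layer on its own to see what the wrapper adds.
dense = tf.keras.layers.Dense(3)
wn_dense = tfa.layers.WeightNormalization(dense, data_init=True)
out = wn_dense(tf.random.normal([4, 5]))   # first call builds and initializes
print([w.name for w in wn_dense.weights])  # includes the magnitude and direction variables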

Load Data

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Convert class vectors to binary class matrices.
y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 11s 0us/step

Train Models

reg_model.compile(optimizer='adam', 
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

reg_history = reg_model.fit(x_train, y_train,
                            batch_size=batch_size,
                            epochs=epochs,
                            validation_data=(x_test, y_test),
                            shuffle=True)
Epoch 1/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.6373 - accuracy: 0.4014 - val_loss: 1.4192 - val_accuracy: 0.4774
Epoch 2/10
1563/1563 [==============================] - 4s 3ms/step - loss: 1.3673 - accuracy: 0.5106 - val_loss: 1.3043 - val_accuracy: 0.5369
Epoch 3/10
1563/1563 [==============================] - 4s 3ms/step - loss: 1.2437 - accuracy: 0.5571 - val_loss: 1.2662 - val_accuracy: 0.5441
Epoch 4/10
1563/1563 [==============================] - 4s 3ms/step - loss: 1.1561 - accuracy: 0.5916 - val_loss: 1.1837 - val_accuracy: 0.5858
Epoch 5/10
1563/1563 [==============================] - 4s 3ms/step - loss: 1.0962 - accuracy: 0.6118 - val_loss: 1.1664 - val_accuracy: 0.5898
Epoch 6/10
1563/1563 [==============================] - 4s 3ms/step - loss: 1.0444 - accuracy: 0.6311 - val_loss: 1.1396 - val_accuracy: 0.6047
Epoch 7/10
1563/1563 [==============================] - 4s 3ms/step - loss: 0.9957 - accuracy: 0.6496 - val_loss: 1.1266 - val_accuracy: 0.6101
Epoch 8/10
1563/1563 [==============================] - 4s 3ms/step - loss: 0.9555 - accuracy: 0.6633 - val_loss: 1.1521 - val_accuracy: 0.6028
Epoch 9/10
1563/1563 [==============================] - 4s 3ms/step - loss: 0.9167 - accuracy: 0.6772 - val_loss: 1.1309 - val_accuracy: 0.6132
Epoch 10/10
1563/1563 [==============================] - 4s 3ms/step - loss: 0.8779 - accuracy: 0.6891 - val_loss: 1.1575 - val_accuracy: 0.6045
wn_model.compile(optimizer='adam', 
                 loss='categorical_crossentropy',
                 metrics=['accuracy'])

wn_history = wn_model.fit(x_train, y_train,
                          batch_size=batch_size,
                          epochs=epochs,
                          validation_data=(x_test, y_test),
                          shuffle=True)
Epoch 1/10
1563/1563 [==============================] - 8s 5ms/step - loss: 1.5913 - accuracy: 0.4222 - val_loss: 1.4270 - val_accuracy: 0.4791
Epoch 2/10
1563/1563 [==============================] - 7s 5ms/step - loss: 1.3218 - accuracy: 0.5244 - val_loss: 1.2732 - val_accuracy: 0.5428
Epoch 3/10
1563/1563 [==============================] - 7s 5ms/step - loss: 1.2096 - accuracy: 0.5673 - val_loss: 1.2939 - val_accuracy: 0.5403
Epoch 4/10
1563/1563 [==============================] - 8s 5ms/step - loss: 1.1279 - accuracy: 0.5982 - val_loss: 1.1730 - val_accuracy: 0.5866
Epoch 5/10
1563/1563 [==============================] - 7s 5ms/step - loss: 1.0682 - accuracy: 0.6190 - val_loss: 1.1559 - val_accuracy: 0.5880
Epoch 6/10
1563/1563 [==============================] - 7s 5ms/step - loss: 1.0139 - accuracy: 0.6412 - val_loss: 1.1333 - val_accuracy: 0.6025
Epoch 7/10
1563/1563 [==============================] - 8s 5ms/step - loss: 0.9641 - accuracy: 0.6584 - val_loss: 1.1257 - val_accuracy: 0.6100
Epoch 8/10
1563/1563 [==============================] - 8s 5ms/step - loss: 0.9190 - accuracy: 0.6736 - val_loss: 1.1275 - val_accuracy: 0.6057
Epoch 9/10
1563/1563 [==============================] - 8s 5ms/step - loss: 0.8774 - accuracy: 0.6897 - val_loss: 1.1257 - val_accuracy: 0.6120
Epoch 10/10
1563/1563 [==============================] - 8s 5ms/step - loss: 0.8379 - accuracy: 0.7040 - val_loss: 1.1223 - val_accuracy: 0.6183
reg_accuracy = reg_history.history['accuracy']
wn_accuracy = wn_history.history['accuracy']

plt.plot(np.linspace(0, epochs, epochs), reg_accuracy,
         color='red', label='Regular ConvNet')

plt.plot(np.linspace(0, epochs, epochs), wn_accuracy,
         color='blue', label='WeightNorm ConvNet')

plt.title('WeightNorm Accuracy Comparison')
plt.legend()
plt.grid(True)
plt.show()

[Figure: "WeightNorm Accuracy Comparison" plot of training accuracy per epoch for the Regular and WeightNorm ConvNets]
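
The curves above track training accuracy only; the fit histories also record the test-split metrics, so a validation comparison is a small variation on the same plotting code (a sketch reusing the val_accuracy keys Keras records when validation_data is passed):

# Sketch: the same comparison on the validation (test) split.
plt.plot(reg_history.history['val_accuracy'], color='red', label='Regular ConvNet')
plt.plot(wn_history.history['val_accuracy'], color='blue', label='WeightNorm ConvNet')
plt.title('WeightNorm Validation Accuracy Comparison')
plt.legend()
plt.grid(True)
plt.show()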