TensorFlow.orgで表示 | GoogleColabで実行 | GitHubでソースを表示 | ノートブックをダウンロード |
概要
このノートブックは簡単に紹介を与える正規化層TensorFlowの。現在サポートされているレイヤーは次のとおりです。
- グループの正規化(TensorFlowアドオン)
- インスタンスの正規化(TensorFlowアドオン)
- 層の正規化(TensorFlowコア)
これらのレイヤーの背後にある基本的な考え方は、トレーニング中の収束を改善するためにアクティベーションレイヤーの出力を正規化することです。対照的に、バッチ正規これらの正規化ではなく、彼らはだけでなく、再発neualネットワークに適してい、単一のサンプルのアクティベーションを正常化、バッチの作業をしないでください。
通常、正規化は、入力テンソルのサブグループの平均と標準偏差を計算することによって実行されます。これにスケールとオフセット係数を適用することも可能です。
\(y_{i} = \frac{\gamma ( x_{i} - \mu )}{\sigma }+ \beta\)
\( y\) :出力
\(x\) :入力
\(\gamma\) :スケールファクタ
\(\mu\):平均
\(\sigma\):標準偏差
\(\beta\):オフセット要因
次の画像は、これらの手法の違いを示しています。各サブプロットは、入力テンソルを示しています。Nはバッチ軸、Cはチャネル軸、(H、W)は空間軸(たとえば、画像の高さと幅)です。青のピクセルは、これらのピクセルの値を集計することによって計算された、同じ平均と分散によって正規化されます。
出典:( https://arxiv.org/pdf/1803.08494.pdf )
重みガンマとベータは、表現能力の喪失の可能性を補うために、すべての正規化レイヤーでトレーニング可能です。あなたは、設定することにより、これらの因子を活性化することができますcenter
やscale
にフラグをTrue
。もちろん、あなたが使用することができますinitializers
、 constraints
およびregularizer
のためのbeta
およびgamma
、これらの値は、トレーニングプロセス中にチューニングします。
設定
Tensorflow2.0とTensorflow-アドオンをインストールします
pip install -U tensorflow-addons
import tensorflow as tf
import tensorflow_addons as tfa
データセットの準備
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
グループ正規化チュートリアル
序章
Group Normalization(GN)は、入力のチャネルをより小さなサブグループに分割し、それらの平均と分散に基づいてこれらの値を正規化します。 GNは単一の例で機能するため、この手法はバッチサイズに依存しません。
GNは、画像分類タスクでバッチ正規化に近いスコアを実験的に記録しました。全体的なbatch_sizeが低い場合は、バッチ正規化の代わりにGNを使用すると、バッチ正規化のパフォーマンスが低下する可能性があります。
例
Conv2Dレイヤーの後に10チャンネルを、標準の「チャンネルラスト」設定で5つのサブグループに分割します。
model = tf.keras.models.Sequential([
# Reshape into "channels last" setup.
tf.keras.layers.Reshape((28,28,1), input_shape=(28,28)),
tf.keras.layers.Conv2D(filters=10, kernel_size=(3,3),data_format="channels_last"),
# Groupnorm Layer
tfa.layers.GroupNormalization(groups=5, axis=3),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_test, y_test)
313/313 [==============================] - 3s 3ms/step - loss: 0.4707 - accuracy: 0.8613 <keras.callbacks.History at 0x7f63a5c5f490>
インスタンス正規化チュートリアル
序章
インスタンスの正規化は、グループのサイズがチャネルサイズ(または軸のサイズ)と同じサイズであるグループの正規化の特殊なケースです。
実験結果は、バッチ正規化を置き換える場合、インスタンスの正規化がスタイル転送でうまく機能することを示しています。最近、インスタンスの正規化は、GANのバッチ正規化の代わりとしても使用されています。
例
Conv2Dレイヤーの後にInstanceNormalizationを適用し、均一に初期化されたスケールとオフセット係数を使用します。
model = tf.keras.models.Sequential([
# Reshape into "channels last" setup.
tf.keras.layers.Reshape((28,28,1), input_shape=(28,28)),
tf.keras.layers.Conv2D(filters=10, kernel_size=(3,3),data_format="channels_last"),
# LayerNorm Layer
tfa.layers.InstanceNormalization(axis=3,
center=True,
scale=True,
beta_initializer="random_uniform",
gamma_initializer="random_uniform"),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_test, y_test)
313/313 [==============================] - 1s 3ms/step - loss: 0.5367 - accuracy: 0.8405 <keras.callbacks.History at 0x7f63a58d9f50>
レイヤー正規化チュートリアル
序章
層の正規化は、グループのサイズが1であるグループの正規化の特殊なケースです。平均と標準偏差は、単一のサンプルのすべてのアクティブ化から計算されます。
実験結果は、レイヤーの正規化がバッチサイズに依存せずに機能するため、リカレントニューラルネットワークに適していることを示しています。
例
Conv2Dレイヤーの後にレイヤー正規化を適用し、スケールとオフセット係数を使用します。
model = tf.keras.models.Sequential([
# Reshape into "channels last" setup.
tf.keras.layers.Reshape((28,28,1), input_shape=(28,28)),
tf.keras.layers.Conv2D(filters=10, kernel_size=(3,3),data_format="channels_last"),
# LayerNorm Layer
tf.keras.layers.LayerNormalization(axis=3 , center=True , scale=True),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
model.fit(x_test, y_test)
313/313 [==============================] - 1s 3ms/step - loss: 0.4056 - accuracy: 0.8754 <keras.callbacks.History at 0x7f63a5722d10>