TensorFlow.org で表示 | Google Colab で実行 | GitHub でソースを表示 | ノートブックをダウンロード |
概要
tf.distribute.Strategy
API は、複数の処理ユニットに渡ってトレーニングを分散するための抽象化を提供します。ユーザーは既存のモデルとトレーニングコードを使用して、最小限の変更で分散型トレーニングを実行できるようになります。
このチュートリアルでは、tf.distribute.MirroredStrategy
を使用して、1 台のマシンの多数の GPU で同期トレーニングを行うグラフ内レプリケーションを実行します。ストラテジーは基本的にモデルのすべての変数を各プロセッサにコピーします。その後、all-reduce を使用して全プロセッサからの勾配を結合し、結合された値をモデルの全コピーに適用します。
tf.keras
API を使用して、モデルとそれをトレーニングするための Model.fit
をビルドします。(カスタムトレーニングループと MirroredStrategy
を使った分散型トレーニングについては、こちらのチュートリアルをご覧ください。)
MirroredStrategy
は単一のマシンの複数の GPU でモデルをトレーニングします。複数のワーカーの多数の GPU で同期トレーニングを行う場合は、tf.distribute.MultiWorkerMirroredStrategy
とKeras の Model.fit かカスタムトレーニングループを使用します。その他のオプションについては、分散型トレーニングガイドをご覧ください。
その他のさまざまなストラテジーについては、TensorFlow の分散型トレーニングガイドをご覧ください。
セットアップ
import tensorflow_datasets as tfds
import tensorflow as tf
import os
# Load the TensorBoard notebook extension.
%load_ext tensorboard
2024-01-11 18:18:57.406481: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-01-11 18:18:57.406527: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-01-11 18:18:57.408011: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
print(tf.__version__)
2.15.0
データセットをダウンロードする
TensorFlow Datasets から MNIST データセットを読み込みます。これは、tf.data
形式のデータセットを返します。
with_info
引数を True
に設定すると、データセット全体に対するメタデータが含まれます。ここでは info
に保存されます。このメタデータオブジェクトには、トレーニングとテストの例の数などが含まれます。
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
分散ストラテジーを定義する
MirroredStrategy
オブジェクトを作成します。これは分散を処理し、モデル内に構築するコンテキストマネージャ (MirroredStrategy.scope
) を提供します。
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4
入力パイプラインをセットアップする
マルチ GPU でモデルをトレーニングする場合、バッチサイズを増加することにより追加の計算能力を効果的に利用することができます。一般的には、GPU メモリに収まる最大のバッチサイズを使用し、それに応じて学習率を調整します。
# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
画像ピクセル値を [0, 255]
の範囲から [0, 1]
の範囲に正規化する関数を定義します(特徴量スケーリング)。
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
この scale
関数をトレーニングとテストのデータに適用してから、tf.data.Dataset
API を使用してトレーニングデータをシャッフル(Dataset.shuffle
)し、バッチ化(Dataset.batch
)します。パフォーマンスを改善するために、トレーニングデータのインメモリキャッシュも保持していることに注意してください(Dataset.cache
)。
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
モデルを作成してオプティマイザをインスタンス化する
Strategy.scope
のコンテキスト内で、Keras API を使ってモデルを作成し、コンパイルします。
with strategy.scope():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
metrics=['accuracy'])
この MNIST データセットを使ったトイサンプルでは、Adam オプティマイザのデフォルトの学習率である 0.001 を使用します。
より大規模なデータセットの場合、分散トレーニングの主なメリットはトレーニングステップごとにより多くの学習を行えることです。これは、各ステップがより多くのトレーニングデータを並行して処理するため、(モデルとデータセットの制限内で)より大きな学習率が可能となるためです。
コールバックを定義する
以下の Keras コールバックを定義します。
tf.keras.callbacks.TensorBoard
: グラフを視覚化できるように、TensorBoard 用のログを書き込みます。tf.keras.callbacks.ModelCheckpoint
: 各エポック後など、特定の頻度でモデルを保存します。tf.keras.callbacks.BackupAndRestore
: モデルと現在のエポック番号をバックアップすることで、フォールトトレランス機能を提供します。詳細は、Keras によるマルチワーカートレーニングチュートリアルのフォールトトレランスセクションをご覧ください。tf.keras.callbacks.LearningRateScheduler
: schedules the learning rate to change after, for example, every epoch/batch.
このノートブックでは例示目的で、PrintLR
というカスタムコールバックを追加して、学習率を表示します。
注意: ジョブの失敗から再開する際に、トレーニング状態をリストアするための主なメカニズムとして、ModelCheckpoint
の代わりに BackupAndRestore
コールバックを使用してください。BackupAndRestore
は eager モードのみをサポートするため、graph モードでは ModelCheckpoint
を使用することを検討してください。
# Define the checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
# Define the name of the checkpoint files.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Define a function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
if epoch < 3:
return 1e-3
elif epoch >= 3 and epoch < 7:
return 1e-4
else:
return 1e-5
# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
def on_epoch_end(self, epoch, logs=None):
print('\nLearning rate for epoch {} is {}'.format( epoch + 1, model.optimizer.lr.numpy()))
# Put all the callbacks together.
callbacks = [
tf.keras.callbacks.TensorBoard(log_dir='./logs'),
tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
save_weights_only=True),
tf.keras.callbacks.LearningRateScheduler(decay),
PrintLR()
]
トレーニングして評価する
次に、通常の方法でモデルをトレーニングします。モデル上で Keras Model.fit
を呼び出し、チュートリアルの最初に作成したデータセットを渡します。トレーニングを分散しているかに関わらず、このステップは同じです。
EPOCHS = 12
model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)
2024-01-11 18:19:03.458640: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. Epoch 1/12 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1704997150.075966 59417 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 1/235 [..............................] - ETA: 24:58 - loss: 2.3099 - accuracy: 0.0938WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0076s vs `on_train_batch_end` time: 0.0139s). Check your callbacks. WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0076s vs `on_train_batch_end` time: 0.0139s). Check your callbacks. 235/235 [==============================] - ETA: 0s - loss: 0.3306 - accuracy: 0.9069INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). Learning rate for epoch 1 is 0.0010000000474974513 235/235 [==============================] - 8s 8ms/step - loss: 0.3306 - accuracy: 0.9069 - lr: 0.0010 Epoch 2/12 233/235 [============================>.] - ETA: 0s - loss: 0.1105 - accuracy: 0.9682 Learning rate for epoch 2 is 0.0010000000474974513 235/235 [==============================] - 2s 7ms/step - loss: 0.1104 - accuracy: 0.9683 - lr: 0.0010 Epoch 3/12 230/235 [============================>.] - ETA: 0s - loss: 0.0728 - accuracy: 0.9787 Learning rate for epoch 3 is 0.0010000000474974513 235/235 [==============================] - 2s 7ms/step - loss: 0.0726 - accuracy: 0.9787 - lr: 0.0010 Epoch 4/12 233/235 [============================>.] - ETA: 0s - loss: 0.0509 - accuracy: 0.9860 Learning rate for epoch 4 is 9.999999747378752e-05 235/235 [==============================] - 2s 7ms/step - loss: 0.0508 - accuracy: 0.9861 - lr: 1.0000e-04 Epoch 5/12 233/235 [============================>.] - ETA: 0s - loss: 0.0482 - accuracy: 0.9868 Learning rate for epoch 5 is 9.999999747378752e-05 235/235 [==============================] - 2s 7ms/step - loss: 0.0482 - accuracy: 0.9869 - lr: 1.0000e-04 Epoch 6/12 230/235 [============================>.] - ETA: 0s - loss: 0.0463 - accuracy: 0.9874 Learning rate for epoch 6 is 9.999999747378752e-05 235/235 [==============================] - 2s 7ms/step - loss: 0.0464 - accuracy: 0.9873 - lr: 1.0000e-04 Epoch 7/12 233/235 [============================>.] - ETA: 0s - loss: 0.0448 - accuracy: 0.9878 Learning rate for epoch 7 is 9.999999747378752e-05 235/235 [==============================] - 2s 7ms/step - loss: 0.0447 - accuracy: 0.9878 - lr: 1.0000e-04 Epoch 8/12 233/235 [============================>.] - ETA: 0s - loss: 0.0425 - accuracy: 0.9888 Learning rate for epoch 8 is 9.999999747378752e-06 235/235 [==============================] - 2s 7ms/step - loss: 0.0424 - accuracy: 0.9887 - lr: 1.0000e-05 Epoch 9/12 232/235 [============================>.] - ETA: 0s - loss: 0.0421 - accuracy: 0.9888 Learning rate for epoch 9 is 9.999999747378752e-06 235/235 [==============================] - 2s 7ms/step - loss: 0.0422 - accuracy: 0.9888 - lr: 1.0000e-05 Epoch 10/12 233/235 [============================>.] - ETA: 0s - loss: 0.0420 - accuracy: 0.9889 Learning rate for epoch 10 is 9.999999747378752e-06 235/235 [==============================] - 2s 7ms/step - loss: 0.0420 - accuracy: 0.9888 - lr: 1.0000e-05 Epoch 11/12 233/235 [============================>.] - ETA: 0s - loss: 0.0419 - accuracy: 0.9889 Learning rate for epoch 11 is 9.999999747378752e-06 235/235 [==============================] - 2s 7ms/step - loss: 0.0418 - accuracy: 0.9888 - lr: 1.0000e-05 Epoch 12/12 232/235 [============================>.] - ETA: 0s - loss: 0.0417 - accuracy: 0.9889 Learning rate for epoch 12 is 9.999999747378752e-06 235/235 [==============================] - 2s 7ms/step - loss: 0.0416 - accuracy: 0.9889 - lr: 1.0000e-05 <keras.src.callbacks.History at 0x7f873008aa00>
保存済みのチェックポイントを確認します。
# Check the checkpoint directory.
ls {checkpoint_dir}
checkpoint ckpt_4.data-00000-of-00001 ckpt_1.data-00000-of-00001 ckpt_4.index ckpt_1.index ckpt_5.data-00000-of-00001 ckpt_10.data-00000-of-00001 ckpt_5.index ckpt_10.index ckpt_6.data-00000-of-00001 ckpt_11.data-00000-of-00001 ckpt_6.index ckpt_11.index ckpt_7.data-00000-of-00001 ckpt_12.data-00000-of-00001 ckpt_7.index ckpt_12.index ckpt_8.data-00000-of-00001 ckpt_2.data-00000-of-00001 ckpt_8.index ckpt_2.index ckpt_9.data-00000-of-00001 ckpt_3.data-00000-of-00001 ckpt_9.index ckpt_3.index
モデルがどれほどうまく実行するかを確認するために、最新のチェックポイントを読み込み、テストデータで Model.evaluate
を呼び出します。
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
2024-01-11 18:19:35.435996: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. 40/40 [==============================] - 2s 8ms/step - loss: 0.0527 - accuracy: 0.9815 Eval loss: 0.052657630294561386, Eval accuracy: 0.9815000295639038
出力を視覚化するために、TensorBoard を起動して、ログを表示します。
%tensorboard --logdir=logs
ls -sh ./logs
total 4.0K 4.0K train
モデルを保存する
Model.save
を使用して、モデルを .keras
zip アーカイブに保存します。モデルが保存されたら、Strategy.scope
の有無に関係なくそれを読み込めるようになります。
path = 'my_model.keras'
model.save(path)
次に、Strategy.scope
を使用せずにモデルを読み込みます。
unreplicated_model = tf.keras.models.load_model(path)
unreplicated_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
40/40 [==============================] - 0s 4ms/step - loss: 0.0527 - accuracy: 0.9815 Eval loss: 0.052657630294561386, Eval Accuracy: 0.9815000295639038
Strategy.scope
を使用してモデルを読み込みます。
with strategy.scope():
replicated_model = tf.keras.models.load_model(path)
replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=['accuracy'])
eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
2024-01-11 18:19:38.910821: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. 40/40 [==============================] - 2s 5ms/step - loss: 0.0527 - accuracy: 0.9815 Eval loss: 0.052657630294561386, Eval Accuracy: 0.9815000295639038
追加リソース
さまざまな分散ストラテジーと Keras Model.fit
API を使用したその他の例をご覧ください。
- TPU で BERT を使って GLUE タスクを解決するチュートリアルでは、GPU でのトレーニングには
tf.distribute.MirroredStrategy
を使用し、TPU ではtf.distribute.TPUStrategy
を使用しています。 - 分散ストラテジーを使ってモデルを保存して読み込むチュートリアルでは、SavedModel API と
tf.distribute.Strategy
の使用方法が説明されています。 - TensorFlow 公式モデルは、複数の分散ストラテジーを実行できるように構成可能です。
TensorFlow 分散ストラテジーに関してさらに学習するには、以下をご覧ください。
- tf.distribute.Strategy によるカスタムトレーニングチュートリアルでは、カスタムトレーニングループを使って単一ワーカートレーニングに
tf.distribute.MirroredStrategy
を使用する方法が説明されています。 - Keras によるマルチワーカートレーニングのチュートリアルでは、
MultiWorkerMirroredStrategy
とModel.fit
を使用する方法が説明されています。 - Keras によるカスタムトレーニングループと MultiWorkerMirroredStrategy のチュートリアルでは、Keras とカスタムトレーニングループで
MultiWorkerMirroredStrategy
を使用する方法が説明されています。 - TensorFlow での分散型トレーニングガイドでは、利用可能な分散ストラテジーの概要が説明されています。
- tf.function を使ったパフォーマンスの改善ガイドでは、その他のストラテジーや、TensorFlow モデルのパフォーマンスを最適化するために使用できる TensorFlow Profiler といったツールに関する情報が提供されています。
注意: tf.distribute.Strategy
の開発は積極積に進められています。近日中にはより多くの例やチュートリアルを追加する予定ですので、ぜひお試しください。フィードバックをお待ちしております。GitHub の課題から、お気軽にお寄せください。