Keras による分散型トレーニング

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

概要

tf.distribute.Strategy API は、複数の処理ユニットに渡ってトレーニングを分散するための抽象化を提供します。ユーザーは既存のモデルとトレーニングコードを使用して、最小限の変更で分散型トレーニングを実行できるようになります。

このチュートリアルでは、tf.distribute.MirroredStrategy を使用して、1 台のマシンの多数の GPU で同期トレーニングを行うグラフ内レプリケーションを実行します。ストラテジーは基本的にモデルのすべての変数を各プロセッサにコピーします。その後、all-reduce を使用して全プロセッサからの勾配を結合し、結合された値をモデルの全コピーに適用します。

tf.keras API を使用して、モデルとそれをトレーニングするための Model.fit をビルドします。(カスタムトレーニングループと MirroredStrategy を使った分散型トレーニングについては、こちらのチュートリアルをご覧ください。)

MirroredStrategy は単一のマシンの複数の GPU でモデルをトレーニングします。複数のワーカーの多数の GPU で同期トレーニングを行う場合は、tf.distribute.MultiWorkerMirroredStrategyKeras の Model.fitカスタムトレーニングループを使用します。その他のオプションについては、分散型トレーニングガイドをご覧ください。

その他のさまざまなストラテジーについては、TensorFlow の分散型トレーニングガイドをご覧ください。

セットアップ

import tensorflow_datasets as tfds
import tensorflow as tf

import os

# Load the TensorBoard notebook extension.
%load_ext tensorboard
2024-01-11 18:18:57.406481: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-01-11 18:18:57.406527: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-01-11 18:18:57.408011: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
print(tf.__version__)
2.15.0

データセットをダウンロードする

TensorFlow Datasets から MNIST データセットを読み込みます。これは、tf.data 形式のデータセットを返します。

with_info 引数を True に設定すると、データセット全体に対するメタデータが含まれます。ここでは info に保存されます。このメタデータオブジェクトには、トレーニングとテストの例の数などが含まれます。

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

mnist_train, mnist_test = datasets['train'], datasets['test']

分散ストラテジーを定義する

MirroredStrategy オブジェクトを作成します。これは分散を処理し、モデル内に構築するコンテキストマネージャ (MirroredStrategy.scope) を提供します。

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4

入力パイプラインをセットアップする

マルチ GPU でモデルをトレーニングする場合、バッチサイズを増加することにより追加の計算能力を効果的に利用することができます。一般的には、GPU メモリに収まる最大のバッチサイズを使用し、それに応じて学習率を調整します。

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

画像ピクセル値を [0, 255] の範囲から [0, 1] の範囲に正規化する関数を定義します(特徴量スケーリング)。

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

この scale 関数をトレーニングとテストのデータに適用してから、tf.data.Dataset API を使用してトレーニングデータをシャッフル(Dataset.shuffle)し、バッチ化(Dataset.batch)します。パフォーマンスを改善するために、トレーニングデータのインメモリキャッシュも保持していることに注意してください(Dataset.cache)。

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

モデルを作成してオプティマイザをインスタンス化する

Strategy.scope のコンテキスト内で、Keras API を使ってモデルを作成し、コンパイルします。

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10)
  ])

  model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
                metrics=['accuracy'])

この MNIST データセットを使ったトイサンプルでは、Adam オプティマイザのデフォルトの学習率である 0.001 を使用します。

より大規模なデータセットの場合、分散トレーニングの主なメリットはトレーニングステップごとにより多くの学習を行えることです。これは、各ステップがより多くのトレーニングデータを並行して処理するため、(モデルとデータセットの制限内で)より大きな学習率が可能となるためです。

コールバックを定義する

以下の Keras コールバックを定義します。

このノートブックでは例示目的で、PrintLR というカスタムコールバックを追加して、学習率を表示します。

注意: ジョブの失敗から再開する際に、トレーニング状態をリストアするための主なメカニズムとして、ModelCheckpoint の代わりに BackupAndRestore コールバックを使用してください。BackupAndRestore は eager モードのみをサポートするため、graph モードでは ModelCheckpoint を使用することを検討してください。

# Define the checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
# Define the name of the checkpoint files.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Define a function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5
# Define a callback for printing the learning rate at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print('\nLearning rate for epoch {} is {}'.format(        epoch + 1, model.optimizer.lr.numpy()))
# Put all the callbacks together.
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

トレーニングして評価する

次に、通常の方法でモデルをトレーニングします。モデル上で Keras Model.fit を呼び出し、チュートリアルの最初に作成したデータセットを渡します。トレーニングを分散しているかに関わらず、このステップは同じです。

EPOCHS = 12

model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)
2024-01-11 18:19:03.458640: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
Epoch 1/12
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1704997150.075966   59417 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
1/235 [..............................] - ETA: 24:58 - loss: 2.3099 - accuracy: 0.0938WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0076s vs `on_train_batch_end` time: 0.0139s). Check your callbacks.
WARNING:tensorflow:Callback method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0076s vs `on_train_batch_end` time: 0.0139s). Check your callbacks.
235/235 [==============================] - ETA: 0s - loss: 0.3306 - accuracy: 0.9069INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Learning rate for epoch 1 is 0.0010000000474974513
235/235 [==============================] - 8s 8ms/step - loss: 0.3306 - accuracy: 0.9069 - lr: 0.0010
Epoch 2/12
233/235 [============================>.] - ETA: 0s - loss: 0.1105 - accuracy: 0.9682
Learning rate for epoch 2 is 0.0010000000474974513
235/235 [==============================] - 2s 7ms/step - loss: 0.1104 - accuracy: 0.9683 - lr: 0.0010
Epoch 3/12
230/235 [============================>.] - ETA: 0s - loss: 0.0728 - accuracy: 0.9787
Learning rate for epoch 3 is 0.0010000000474974513
235/235 [==============================] - 2s 7ms/step - loss: 0.0726 - accuracy: 0.9787 - lr: 0.0010
Epoch 4/12
233/235 [============================>.] - ETA: 0s - loss: 0.0509 - accuracy: 0.9860
Learning rate for epoch 4 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0508 - accuracy: 0.9861 - lr: 1.0000e-04
Epoch 5/12
233/235 [============================>.] - ETA: 0s - loss: 0.0482 - accuracy: 0.9868
Learning rate for epoch 5 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0482 - accuracy: 0.9869 - lr: 1.0000e-04
Epoch 6/12
230/235 [============================>.] - ETA: 0s - loss: 0.0463 - accuracy: 0.9874
Learning rate for epoch 6 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0464 - accuracy: 0.9873 - lr: 1.0000e-04
Epoch 7/12
233/235 [============================>.] - ETA: 0s - loss: 0.0448 - accuracy: 0.9878
Learning rate for epoch 7 is 9.999999747378752e-05
235/235 [==============================] - 2s 7ms/step - loss: 0.0447 - accuracy: 0.9878 - lr: 1.0000e-04
Epoch 8/12
233/235 [============================>.] - ETA: 0s - loss: 0.0425 - accuracy: 0.9888
Learning rate for epoch 8 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0424 - accuracy: 0.9887 - lr: 1.0000e-05
Epoch 9/12
232/235 [============================>.] - ETA: 0s - loss: 0.0421 - accuracy: 0.9888
Learning rate for epoch 9 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0422 - accuracy: 0.9888 - lr: 1.0000e-05
Epoch 10/12
233/235 [============================>.] - ETA: 0s - loss: 0.0420 - accuracy: 0.9889
Learning rate for epoch 10 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0420 - accuracy: 0.9888 - lr: 1.0000e-05
Epoch 11/12
233/235 [============================>.] - ETA: 0s - loss: 0.0419 - accuracy: 0.9889
Learning rate for epoch 11 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0418 - accuracy: 0.9888 - lr: 1.0000e-05
Epoch 12/12
232/235 [============================>.] - ETA: 0s - loss: 0.0417 - accuracy: 0.9889
Learning rate for epoch 12 is 9.999999747378752e-06
235/235 [==============================] - 2s 7ms/step - loss: 0.0416 - accuracy: 0.9889 - lr: 1.0000e-05
<keras.src.callbacks.History at 0x7f873008aa00>

保存済みのチェックポイントを確認します。

# Check the checkpoint directory.
ls {checkpoint_dir}
checkpoint           ckpt_4.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_4.index
ckpt_1.index             ckpt_5.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_5.index
ckpt_10.index            ckpt_6.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_6.index
ckpt_11.index            ckpt_7.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_7.index
ckpt_12.index            ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index             ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index

モデルがどれほどうまく実行するかを確認するために、最新のチェックポイントを読み込み、テストデータで Model.evaluate を呼び出します。

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))
2024-01-11 18:19:35.435996: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
40/40 [==============================] - 2s 8ms/step - loss: 0.0527 - accuracy: 0.9815
Eval loss: 0.052657630294561386, Eval accuracy: 0.9815000295639038

出力を視覚化するために、TensorBoard を起動して、ログを表示します。

%tensorboard --logdir=logs

ls -sh ./logs
total 4.0K
4.0K train

モデルを保存する

Model.save を使用して、モデルを .keras zip アーカイブに保存します。モデルが保存されたら、Strategy.scope の有無に関係なくそれを読み込めるようになります。

path = 'my_model.keras'
model.save(path)

次に、Strategy.scope を使用せずにモデルを読み込みます。

unreplicated_model = tf.keras.models.load_model(path)

unreplicated_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=tf.keras.optimizers.Adam(),
    metrics=['accuracy'])

eval_loss, eval_acc = unreplicated_model.evaluate(eval_dataset)

print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
40/40 [==============================] - 0s 4ms/step - loss: 0.0527 - accuracy: 0.9815
Eval loss: 0.052657630294561386, Eval Accuracy: 0.9815000295639038

Strategy.scope を使用してモデルを読み込みます。

with strategy.scope():
  replicated_model = tf.keras.models.load_model(path)
  replicated_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                           optimizer=tf.keras.optimizers.Adam(),
                           metrics=['accuracy'])

  eval_loss, eval_acc = replicated_model.evaluate(eval_dataset)
  print ('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))
2024-01-11 18:19:38.910821: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.
40/40 [==============================] - 2s 5ms/step - loss: 0.0527 - accuracy: 0.9815
Eval loss: 0.052657630294561386, Eval Accuracy: 0.9815000295639038

追加リソース

さまざまな分散ストラテジーと Keras Model.fit API を使用したその他の例をご覧ください。

  1. TPU で BERT を使って GLUE タスクを解決するチュートリアルでは、GPU でのトレーニングには tf.distribute.MirroredStrategy を使用し、TPU では tf.distribute.TPUStrategy を使用しています。
  2. 分散ストラテジーを使ってモデルを保存して読み込むチュートリアルでは、SavedModel API と tf.distribute.Strategy の使用方法が説明されています。
  3. TensorFlow 公式モデルは、複数の分散ストラテジーを実行できるように構成可能です。

TensorFlow 分散ストラテジーに関してさらに学習するには、以下をご覧ください。

  1. tf.distribute.Strategy によるカスタムトレーニングチュートリアルでは、カスタムトレーニングループを使って単一ワーカートレーニングに tf.distribute.MirroredStrategy を使用する方法が説明されています。
  2. Keras によるマルチワーカートレーニングのチュートリアルでは、MultiWorkerMirroredStrategyModel.fit を使用する方法が説明されています。
  3. Keras によるカスタムトレーニングループと MultiWorkerMirroredStrategy のチュートリアルでは、Keras とカスタムトレーニングループでMultiWorkerMirroredStrategy を使用する方法が説明されています。
  4. TensorFlow での分散型トレーニングガイドでは、利用可能な分散ストラテジーの概要が説明されています。
  5. tf.function を使ったパフォーマンスの改善ガイドでは、その他のストラテジーや、TensorFlow モデルのパフォーマンスを最適化するために使用できる TensorFlow Profiler といったツールに関する情報が提供されています。

注意: tf.distribute.Strategy の開発は積極積に進められています。近日中にはより多くの例やチュートリアルを追加する予定ですので、ぜひお試しください。フィードバックをお待ちしております。GitHub の課題から、お気軽にお寄せください。