TensorFlow.org에서 보기 | Google Colab에서 실행하기 | GitHub에서 소스 보기 | 노트북 다운로드하기 |
이 튜토리얼에서는 tf.distribute.Strategy
TensorFlow API와 사용자 정의 훈련 루프를 사용하여 여러 처리 장치(GPU, 여러 머신 또는 TPU)에 훈련을 배포하기 위한 추상화 방법을 보여줍니다. 이 예에서는 28 x 28 크기의 이미지 70,000개가 포함된 Fashion MNIST 데이터세트에서 간단한 컨볼루션 신경망을 훈련합니다.
사용자 정의 훈련 루프는 훈련에 대한 유연성과 더 나은 통제력을 제공합니다. 또한 모델과 훈련 루프를 쉽게 디버그할 수 있습니다.
# Import TensorFlow
import tensorflow as tf
# Helper libraries
import numpy as np
import os
print(tf.__version__)
2022-12-15 01:55:10.915850: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-15 01:55:10.915956: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-15 01:55:10.915967: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly. 2.11.0
Fashion MNIST 데이터세트 다운로드하기
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# Add a dimension to the array -> new shape == (28, 28, 1)
# This is done because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]
# Scale the images to the [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)
변수와 그래프를 분산하는 전략 만들기
tf.distribute.MirroredStrategy
전략이 어떻게 동작할까요?
- 모든 변수와 모델 그래프는 복제본 간에 복제됩니다.
- 입력은 장치에 고르게 분배되어 들어갑니다.
- 각 장치는 주어지는 입력에 대해서 손실(loss)과 그래디언트를 계산합니다.
- 그래디언트들을 전부 더함으로써 모든 장치들 간에 그래디언트들이 동기화됩니다.
- 동기화된 후에, 동일한 업데이트가 각 장치에 있는 변수의 복사본(copies)에 동일하게 적용됩니다.
참고: 아래의 모든 코드를 단일 범위에 넣을 수 있습니다. 이 예에서는 설명을 위해 여러 코드 셀로 나눕니다.
# If the list of devices is not specified in
# `tf.distribute.MirroredStrategy` constructor, they will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4
입력 파이프라인 설정하기
BUFFER_SIZE = len(train_images)
BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
EPOCHS = 10
데이터세트를 만들고 배포합니다.
train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)
2022-12-15 01:55:18.156270: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_UINT8 } } } attr { key: "_cardinality" value { i: 60000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:0" } } attr { key: "output_shapes" value { list { shape { dim { size: 28 } dim { size: 28 } dim { size: 1 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_UINT8 } } } } } 2022-12-15 01:55:18.216329: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:784] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_2" op: "TensorSliceDataset" input: "Placeholder/_0" input: "Placeholder/_1" attr { key: "Toutput_types" value { list { type: DT_FLOAT type: DT_UINT8 } } } attr { key: "_cardinality" value { i: 10000 } } attr { key: "is_files" value { b: false } } attr { key: "metadata" value { s: "\n\024TensorSliceDataset:3" } } attr { key: "output_shapes" value { list { shape { dim { size: 28 } dim { size: 28 } dim { size: 1 } } shape { } } } } attr { key: "replicate_on_split" value { b: false } } experimental_type { type_id: TFT_PRODUCT args { type_id: TFT_DATASET args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_FLOAT } } args { type_id: TFT_TENSOR args { type_id: TFT_UINT8 } } } } }
모델 만들기
tf.keras.Sequential
을 사용하여 모델을 만듭니다. 모델 하위 클래스화 API 또는 함수형 API를 사용하여 이를 수행할 수도 있습니다.
def create_model():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Conv2D(64, 3, activation='relu'),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
return model
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
손실 함수 정의하기
일반적으로 단일 GPU/CPU가 있는 단일 시스템에서 손실 함수는 입력 배치의 예제 수로 나뉩니다.
그렇다면 tf.distribute.Strategy
를 사용할 때 손실을 어떻게 계산해야 할까요?
예를 들어 GPU가 4개 있고 배치 크기가 64라고 가정해 보겠습니다. 하나의 입력 배치가 전체 복제본(4개의 GPU)에 걸쳐 분배되고 각 복제본은 크기 16의 입력을 받습니다.
각 복제본의 모델은 해당 입력으로 순방향 전달을 수행하고 손실을 계산합니다. 이제 손실을 해당 입력의 예제 수로 나누는 대신(BATCH_SIZE_PER_REPLICA = 16), 손실을 GLOBAL_BATCH_SIZE(64)로 나누어야 합니다.
이렇게 하는 이유는 무엇일까요?
- 각 복제본에서 그래디언트가 계산된 후 이를 합산하여 전체 복제본에 걸쳐 동기화되기 때문에 이렇게 해야 합니다.
TensorFlow에서 이 작업을 어떻게 수행할까요?
이 튜토리얼에서와 같이 사용자 지정 훈련 루프를 작성하는 경우 예제당 손실을 합산하고 합계를 GLOBAL_BATCH_SIZE로 나누어야 합니다:
scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE)
또는tf.nn.compute_average_loss
를 사용하여 예제당 손실, 선택적 샘플 가중치, GLOBAL_BATCH_SIZE를 인수로 사용하고 조정된 손실을 반환할 수 있습니다.모델에서 정규화 손실을 사용하는 경우 복제본 수에 따라 손실 값을 확장해야 합니다.
tf.nn.scale_regularization_loss
함수를 사용하여 이를 수행할 수 있습니다.tf.reduce_mean
을 사용하는 것은 권장하지 않습니다. 이렇게 하면 손실이 실제 복제본 배치 크기별로 나눠지는데, 이는 단계별로 다를 수 있습니다.이 축소 및 크기 조정은 keras
model.compile
및model.fit
에서 자동으로 수행됩니다.tf.keras.losses
클래스를 사용하는 경우(아래 예와 같이) 손실 축소는NONE
또는SUM
중 하나로 명시적으로 지정되어야 합니다.AUTO
및SUM_OVER_BATCH_SIZE
는tf.distribute.Strategy
와 함께 사용할 때 허용되지 않습니다.AUTO
는 사용자가 어떤 축소가 필요한지 명시적으로 생각해야 하므로(분산된 경우에 이러한 축소가 올바른지 확인하기 위해) 허용되지 않습니다.SUM_OVER_BATCH_SIZE
는 현재 복제본 배치 크기별로만 나누고 복제본 수로 나누는 일은 사용자에게 맡기는 데, 이를 쉽게 놓칠 수 있다는 점 때문에 허용되지 않습니다. 따라서 사용자가 직접 명시적으로 축소를 수행해야 합니다.labels
이 다차원인 경우 각 샘플의 요소 수에 걸쳐per_example_loss
의 평균을 구합니다. 예를 들어predictions
의 형상이(batch_size, H, W, n_classes)
이고labels
이(batch_size, H, W)
인 경우, 다음과 같이per_example_loss
를 업데이트해야 합니다:per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)
주의: 손실의 형상을 확인하세요.
tf.losses
/tf.keras.losses
의 손실 함수는 일반적으로 입력의 마지막 차원에 대한 평균을 반환합니다. 손실 클래스는 이러한 함수를 래핑합니다. 손실 클래스의 인스턴스를 생성할 때reduction=Reduction.NONE
을 전달하는 것은 "추가적인 축소가 없음"을 의미합니다.[batch, W, H, n_classes]
의 예제 입력 형상을 갖는 범주형 손실의 경우,n_classes
차원이 축소됩니다.losses.mean_squared_error
또는losses.binary_crossentropy
와 같은 포인트별 손실의 경우, 더미 축을 포함시켜[batch, W, H, 1]
이[batch, W, H]
로 축소되도록 합니다. 더미 축이 없으면[batch, W, H]
가[batch, W]
로 잘못 축소됩니다.
with strategy.scope():
# Set reduction to `NONE` so you can do the reduction afterwards and divide by
# global batch size.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True,
reduction=tf.keras.losses.Reduction.NONE)
def compute_loss(labels, predictions):
per_example_loss = loss_object(labels, predictions)
return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
손실과 정확도를 기록하기 위한 지표 정의하기
이 지표(metrics)는 테스트 손실과 훈련 정확도, 테스트 정확도를 기록합니다. .result()
를 사용해서 누적된 통계값들을 언제나 볼 수 있습니다.
with strategy.scope():
test_loss = tf.keras.metrics.Mean(name='test_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='test_accuracy')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
훈련 루프
# A model, an optimizer, and a checkpoint must be created under `strategy.scope`.
with strategy.scope():
model = create_model()
optimizer = tf.keras.optimizers.Adam()
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
def train_step(inputs):
images, labels = inputs
with tf.GradientTape() as tape:
predictions = model(images, training=True)
loss = compute_loss(labels, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
train_accuracy.update_state(labels, predictions)
return loss
def test_step(inputs):
images, labels = inputs
predictions = model(images, training=False)
t_loss = loss_object(labels, predictions)
test_loss.update_state(t_loss)
test_accuracy.update_state(labels, predictions)
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
axis=None)
@tf.function
def distributed_test_step(dataset_inputs):
return strategy.run(test_step, args=(dataset_inputs,))
for epoch in range(EPOCHS):
# TRAIN LOOP
total_loss = 0.0
num_batches = 0
for x in train_dist_dataset:
total_loss += distributed_train_step(x)
num_batches += 1
train_loss = total_loss / num_batches
# TEST LOOP
for x in test_dist_dataset:
distributed_test_step(x)
if epoch % 2 == 0:
checkpoint.save(checkpoint_prefix)
template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
"Test Accuracy: {}")
print(template.format(epoch + 1, train_loss,
train_accuracy.result() * 100, test_loss.result(),
test_accuracy.result() * 100))
test_loss.reset_states()
train_accuracy.reset_states()
test_accuracy.reset_states()
INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1 INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1 Epoch 1, Loss: 0.6803293824195862, Accuracy: 75.58499908447266, Test Loss: 0.47986483573913574, Test Accuracy: 82.55000305175781 Epoch 2, Loss: 0.41395682096481323, Accuracy: 85.15166473388672, Test Loss: 0.4264320433139801, Test Accuracy: 84.72999572753906 Epoch 3, Loss: 0.3555366098880768, Accuracy: 87.24166870117188, Test Loss: 0.3662477135658264, Test Accuracy: 86.47999572753906 Epoch 4, Loss: 0.3242177665233612, Accuracy: 88.31499481201172, Test Loss: 0.34345629811286926, Test Accuracy: 88.02000427246094 Epoch 5, Loss: 0.30220723152160645, Accuracy: 89.08499908447266, Test Loss: 0.33098432421684265, Test Accuracy: 87.98999786376953 Epoch 6, Loss: 0.28694629669189453, Accuracy: 89.50666809082031, Test Loss: 0.31467700004577637, Test Accuracy: 88.5999984741211 Epoch 7, Loss: 0.2713523507118225, Accuracy: 90.05332946777344, Test Loss: 0.3279726803302765, Test Accuracy: 88.04000091552734 Epoch 8, Loss: 0.2599695324897766, Accuracy: 90.46833801269531, Test Loss: 0.2958003878593445, Test Accuracy: 89.38999938964844 Epoch 9, Loss: 0.2476484477519989, Accuracy: 90.95999908447266, Test Loss: 0.3058784306049347, Test Accuracy: 88.55999755859375 Epoch 10, Loss: 0.23547829687595367, Accuracy: 91.32500457763672, Test Loss: 0.284298300743103, Test Accuracy: 89.59000396728516
위의 예제에서 주목해야 하는 부분
for x in ...
구문을 사용하여train_dist_dataset
및test_dist_dataset
를 반복합니다.- 스케일이 조정된 손실은
distributed_train_step
의 반환값입니다.tf.distribute.Strategy.reduce
호출을 사용해서 장치들 간의 스케일이 조정된 손실 값을 전부 합칩니다. 그리고 나서tf.distribute.Strategy.reduce
반환 값을 더하는 식으로 배치 간의 손실을 모읍니다. tf.keras.Metrics
는tf.distribute.Strategy.run
에 의해 실행되는train_step
및test_step
내부에서 업데이트되어야 합니다.tf.distribute.Strategy.run
은 전략의 각 로컬 복제본에서 결과를 반환하며 이 결과를 사용하는 방법에는 여러 가지가 있습니다.tf.distribute.Strategy.reduce
를 수행하여 집계된 값을 얻을 수 있습니다.tf.distribute.Strategy.experimental_local_results
를 수행하여 로컬 복제본당 하나씩 결과에 포함된 값 목록을 가져올 수도 있습니다.
최신 체크포인트를 불러와서 테스트하기
tf.distribute.Strategy
를 사용해서 체크포인트가 만들어진 모델은 전략 사용 여부에 상관없이 불러올 수 있습니다.
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
name='eval_accuracy')
new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
@tf.function
def eval_step(images, labels):
predictions = new_model(images, training=False)
eval_accuracy(labels, predictions)
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
for images, labels in test_dataset:
eval_step(images, labels)
print('Accuracy after restoring the saved model without strategy: {}'.format(
eval_accuracy.result() * 100))
Accuracy after restoring the saved model without strategy: 88.55999755859375
데이터셋에 대해 반복작업을 하는 다른 방법들
반복자(iterator)를 사용하기
만약 주어진 스텝의 수에 따라서 반복하기 원하면서 전체 데이터세트를 보는 것을 원치 않는다면, iter
를 호출하여 반복자를 만들 수 있습니다. 그 다음 명시적으로 next
를 호출합니다. 또한, tf.function 내부 또는 외부에서 데이터세트를 반복하도록 설정할 수 있습니다. 다음은 반복자를 사용하여 tf.function 외부에서 데이터세트를 반복하는 코드 예제입니다.
for _ in range(EPOCHS):
total_loss = 0.0
num_batches = 0
train_iter = iter(train_dist_dataset)
for _ in range(10):
total_loss += distributed_train_step(next(train_iter))
num_batches += 1
average_train_loss = total_loss / num_batches
template = ("Epoch {}, Loss: {}, Accuracy: {}")
print(template.format(epoch + 1, average_train_loss, train_accuracy.result() * 100))
train_accuracy.reset_states()
Epoch 10, Loss: 0.24009506404399872, Accuracy: 90.78125 Epoch 10, Loss: 0.22674480080604553, Accuracy: 91.796875 Epoch 10, Loss: 0.22516441345214844, Accuracy: 91.7578125 Epoch 10, Loss: 0.24468576908111572, Accuracy: 90.546875 Epoch 10, Loss: 0.22958669066429138, Accuracy: 91.8359375 Epoch 10, Loss: 0.2196839302778244, Accuracy: 92.1875 Epoch 10, Loss: 0.2384246289730072, Accuracy: 90.9765625 Epoch 10, Loss: 0.21778538823127747, Accuracy: 91.953125 Epoch 10, Loss: 0.20954498648643494, Accuracy: 91.9921875 Epoch 10, Loss: 0.18693020939826965, Accuracy: 93.0859375
tf.function 내부에서 반복하기
for x in ...
구조를 사용하거나 위에서 했던 것처럼 반복자를 생성하여 tf.function
내부의 전체 입력 train_dist_dataset
에 대해 반복할 수도 있습니다. 아래 예제는 @tf.function
데코레이터로 훈련한 하나의 epoch를 래핑하고 함수 내에서 train_dist_dataset
을 반복하는 것을 보여줍니다.
@tf.function
def distributed_train_epoch(dataset):
total_loss = 0.0
num_batches = 0
for x in dataset:
per_replica_losses = strategy.run(train_step, args=(x,))
total_loss += strategy.reduce(
tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
num_batches += 1
return total_loss / tf.cast(num_batches, dtype=tf.float32)
for epoch in range(EPOCHS):
train_loss = distributed_train_epoch(train_dist_dataset)
template = ("Epoch {}, Loss: {}, Accuracy: {}")
print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))
train_accuracy.reset_states()
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py:461: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options. warnings.warn("To make it possible to preserve tf.data options across " INFO:tensorflow:batch_all_reduce: 8 all-reduces with algorithm = nccl, num_packs = 1 Epoch 1, Loss: 0.22137457132339478, Accuracy: 91.8316650390625 Epoch 2, Loss: 0.21201039850711823, Accuracy: 92.2366714477539 Epoch 3, Loss: 0.20279081165790558, Accuracy: 92.58499908447266 Epoch 4, Loss: 0.1953163594007492, Accuracy: 92.71666717529297 Epoch 5, Loss: 0.18428874015808105, Accuracy: 93.19999694824219 Epoch 6, Loss: 0.17805488407611847, Accuracy: 93.32833099365234 Epoch 7, Loss: 0.17244744300842285, Accuracy: 93.54000091552734 Epoch 8, Loss: 0.16432566940784454, Accuracy: 93.961669921875 Epoch 9, Loss: 0.1565566509962082, Accuracy: 94.34833526611328 Epoch 10, Loss: 0.15320821106433868, Accuracy: 94.36499786376953
장치 간의 훈련 손실 기록하기
노트: 일반적인 규칙으로, tf.keras.Metrics
를 사용하여 샘플당 손실 값을 기록하고 장치 내부에서 값이 합쳐지는 것을 피해야 합니다.
손실 크기 조정 계산이 수행되기 때문에 tf.keras.metrics.Mean
을 사용하여 여러 복제본에서 훈련 손실을 추적하는 것은 권장되지 않습니다.
예를 들어, 다음과 같은 조건의 훈련을 수행한다고 합시다.
- 두개의 장치
- 두개의 샘플들이 각 장치에 의해 처리됩니다.
- 손실 값을 산출합니다: 각각의 장치에 대해 [2, 3]과 [4, 5]
- Global batch size = 4
손실의 스케일 조정을 하면, 손실 값을 더하고 전역 배치 크기로 나누어 각 장치에 대한 샘플당 손실값을 계산할 수 있습니다. 이 경우에는 (2 + 3) / 4 = 1.25
및 (4 + 5) / 4 = 2.25
입니다.
만약 tf.keras.metrics.Mean
을 사용하여 두 복제본의 손실을 추적한다면 결과가 다릅니다. 이 예에서는 결과적으로 메트릭에서 result()
가 호출될 때 total
3.50개와 count
2개가 되고, total
/count
= 1.75가 됩니다. tf.keras.Metrics
로 계산된 손실은 동기화된 복제본 수에 해당하는 추가 인자로 조정됩니다.
예제와 튜토리얼
사용자 정의 훈련루프를 포함한 분산 전략을 사용하는 몇 가지 예제가 있습니다.
- 분산 교육 가이드
MirroredStrategy
를 사용하는 DenseNet 예제.MirroredStrategy
및TPUStrategy
를 사용해서 훈련하는 BERT 예제. 이 예제는 분산 훈련 중에 어떻게 체크포인트로부터 불러오는지와 어떻게 주기적으로 체크포인트를 생성해 내는 지를 이해하기에 정말 좋습니다.keras_use_ctl flag
를 사용해서 활성화 할 수 있는 MirroredStrategy를 이용해서 훈련되는 NCF 예제MirroredStrategy
을 사용해서 훈련되는 NMT 예제.
배포 전략 가이드의 예제 및 튜토리얼에서 더 많은 예제를 찾을 수 있습니다.
다음 단계
- 모델에서 새로운
tf.distribute.Strategy
API를 사용해 보세요. - TensorFlow 모델의 성능을 최적화하는 도구에 대해 자세히 알아보려면
tf.function
을 이용한 성능 향상 및 TensorFlow 프로파일러 가이드를 참조하세요. - 사용 가능한 배포 전략에 대한 개요를 제공하는 TensorFlow의 분산 훈련 가이드를 확인하세요.