Custom training with tf.distribute.Strategy

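This tutorial shows how to use tf.distribute.Strategy with custom training loops. tf.distribute.Strategy is a TensorFlow API that provides an abstraction for distributing training across multiple processing units (GPUs, multiple machines, or TPUs). In this example, you will train a simple convolutional neural network on the Fashion MNIST dataset, which contains 70,000 images of size 28 x 28.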


# Import TensorFlow
import tensorflow as tf

# Helper libraries
import numpy as np
import os

Download the Fashion MNIST dataset

fashion_mnist = tf.keras.datasets.fashion_mnist

(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# Add a dimension to the array -> new shape == (28, 28, 1)
# This is done because the first layer in our model is a convolutional
# layer and it requires a 4D input (batch_size, height, width, channels).
# batch_size dimension will be added later on.
train_images = train_images[..., None]
test_images = test_images[..., None]

# Scale the images to the [0, 1] range.
train_images = train_images / np.float32(255)
test_images = test_images / np.float32(255)


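How does tf.distribute.MirroredStrategy work?

  • All the variables and the model graph are replicated across the replicas.
  • The input is evenly distributed across the replicas.
  • Each replica calculates the loss and gradients for the input it received.
  • The gradients are synced across all the replicas by summing them.
  • After the sync, the same update is applied to the copies of the variables on each replica.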
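The steps above can be simulated without TensorFlow to show why they keep the replicas consistent. The NumPy sketch below is a hypothetical stand-in, not how MirroredStrategy is implemented: a linear model with a squared-error loss replaces the CNN, a Python list of arrays plays the role of the mirrored variables, and a plain sum plays the role of the all-reduce. The names `grad_scaled`, `NUM_REPLICAS`, and `GLOBAL_BATCH` are illustrative assumptions.

```python
import numpy as np

# Hypothetical setup: a linear model y_hat = x @ w with squared-error loss,
# standing in for the CNN used in this tutorial.
rng = np.random.default_rng(0)
NUM_REPLICAS = 4
GLOBAL_BATCH = 64
x = rng.normal(size=(GLOBAL_BATCH, 3))
y = rng.normal(size=GLOBAL_BATCH)
w = np.zeros(3)

def grad_scaled(w, xb, yb, global_batch):
    # Gradient of sum((xb @ w - yb) ** 2) / global_batch with respect to w.
    residual = xb @ w - yb
    return 2.0 * xb.T @ residual / global_batch

# 1. Each replica holds its own copy of the variables.
replica_weights = [w.copy() for _ in range(NUM_REPLICAS)]
# 2. The global batch is split evenly across replicas (16 examples each).
x_shards = np.array_split(x, NUM_REPLICAS)
y_shards = np.array_split(y, NUM_REPLICAS)
# 3. Each replica computes gradients on its own shard only.
local_grads = [grad_scaled(rw, xs, ys, GLOBAL_BATCH)
               for rw, xs, ys in zip(replica_weights, x_shards, y_shards)]
# 4. "All-reduce": sum the gradients across replicas.
summed = np.sum(local_grads, axis=0)
# 5. Every replica applies the same summed update, so the copies stay identical.
lr = 0.1
replica_weights = [rw - lr * summed for rw in replica_weights]

# Reference: a single device computing the gradient over the whole batch.
single_device = w - lr * grad_scaled(w, x, y, GLOBAL_BATCH)
```

Because every replica applies the same summed gradient, the copies never drift apart, and the result matches the update a single device would make on the full batch.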


# If the list of devices is not specified in
# `tf.distribute.MirroredStrategy` constructor, they will be auto-detected.
strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 4


BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE) 
test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE) 

train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)


Create the model using tf.keras.Sequential. You can also use the Model Subclassing API or the functional API to do this.

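def create_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Conv2D(64, 3, activation='relu'),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      # One output logit per Fashion MNIST class (no softmax).
      tf.keras.layers.Dense(10)
    ])

  return model
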
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")


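Normally, on a single machine with a single GPU/CPU, the loss is divided by the number of examples in the batch of input.

So, how should the loss be calculated when using a tf.distribute.Strategy?

  • For example, suppose you have 4 GPUs and a global batch size of 64. One batch of input is distributed across the replicas (4 GPUs), and each replica gets an input of size 16.

  • The model on each replica does a forward pass with its respective input and calculates the loss. Now, instead of dividing the loss by the number of examples in its respective input (BATCH_SIZE_PER_REPLICA = 16), the loss should be divided by GLOBAL_BATCH_SIZE (64).

  • This is necessary because, after the gradients are calculated on each replica, they are synced across the replicas by summing them. Dividing by the global batch size makes the summed gradient equal to the gradient of the average loss over the whole global batch; dividing by the per-replica batch size would make it 4 times too large.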
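A quick arithmetic check of the 4-GPU example, written as a NumPy sketch with made-up per-example losses 0, 1, ..., 63. The losses stand in for gradients here: because the gradient is linear in the loss, the same factor carries over to the summed gradients.

```python
import numpy as np

NUM_REPLICAS = 4
GLOBAL_BATCH_SIZE = 64
BATCH_SIZE_PER_REPLICA = GLOBAL_BATCH_SIZE // NUM_REPLICAS  # 16

# Hypothetical per-example losses for one global batch: 0.0, 1.0, ..., 63.0.
per_example_loss = np.arange(GLOBAL_BATCH_SIZE, dtype=np.float64)
shards = np.split(per_example_loss, NUM_REPLICAS)

# Wrong: each replica averages over its own 16 examples; results are then summed.
wrong = sum(shard.sum() / BATCH_SIZE_PER_REPLICA for shard in shards)
# Right: each replica divides its sum by the global batch size; results are then summed.
right = sum(shard.sum() / GLOBAL_BATCH_SIZE for shard in shards)

print(per_example_loss.mean(), right, wrong)  # 31.5 31.5 126.0
```

Dividing by the per-replica batch size overshoots by exactly the number of replicas (126.0 = 4 × 31.5).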

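How do you do this in TensorFlow?

  • If you're writing a custom training loop, as in this tutorial, sum the per-example losses and divide the sum by GLOBAL_BATCH_SIZE: scale_loss = tf.reduce_sum(loss) * (1. / GLOBAL_BATCH_SIZE). Alternatively, use tf.nn.compute_average_loss, which takes the per-example loss, optional sample weights, and GLOBAL_BATCH_SIZE as arguments and returns the scaled loss.

  • If you're using regularization losses in your model, you need to scale (divide) the loss value by the number of replicas. You can do this with the tf.nn.scale_regularization_loss function.

  • Using tf.reduce_mean is not recommended. Doing so divides the loss by the actual per-replica batch size, which may vary from step to step (for example, the last batch of an epoch can be smaller).

  • This reduction and scaling is done automatically in Keras Model.compile and Model.fit.

  • If you use tf.keras.losses classes (as in the example below), the loss reduction needs to be explicitly specified as either NONE or SUM. AUTO and SUM_OVER_BATCH_SIZE are disallowed when used with tf.distribute.Strategy. AUTO is disallowed because you should explicitly think about which reduction you want, to make sure it is correct in the distributed case. SUM_OVER_BATCH_SIZE is disallowed because currently it only divides by the per-replica batch size and leaves the division by the number of replicas to you, which is easy to miss. So, instead, you need to do the reduction yourself explicitly.

  • If labels is multi-dimensional, average the per_example_loss across the number of elements in each sample. For example, if the shape of predictions is (batch_size, H, W, n_classes) and labels is (batch_size, H, W), update per_example_loss like this: per_example_loss /= tf.cast(tf.reduce_prod(tf.shape(labels)[1:]), tf.float32)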
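To make the arithmetic in these bullets concrete, here is a NumPy sketch of the scaling they describe. `compute_average_loss_np` and `scale_regularization_loss_np` are hypothetical helpers written for illustration; they follow the behavior described above for tf.nn.compute_average_loss (weighted sum divided by the global batch size) and tf.nn.scale_regularization_loss (sum divided by the number of replicas), but they are not the TensorFlow implementations.

```python
import numpy as np

def compute_average_loss_np(per_example_loss, global_batch_size, sample_weight=None):
    # Weight each example's loss (if weights are given), sum everything,
    # then divide by the GLOBAL batch size, not the per-replica one.
    loss = np.asarray(per_example_loss, dtype=np.float64)
    if sample_weight is not None:
        loss = loss * np.asarray(sample_weight, dtype=np.float64)
    return loss.sum() / global_batch_size

def scale_regularization_loss_np(regularization_losses, num_replicas):
    # Every replica adds the full regularization term and the per-replica
    # values are summed later, so each replica contributes 1 / num_replicas.
    return np.sum(regularization_losses) / num_replicas

# One replica holding 4 examples out of a global batch of 8.
replica_loss = compute_average_loss_np([1.0, 2.0, 3.0, 4.0], global_batch_size=8)
weighted_loss = compute_average_loss_np([1.0, 2.0, 3.0, 4.0], global_batch_size=8,
                                        sample_weight=[1.0, 0.0, 1.0, 0.0])
reg_per_replica = scale_regularization_loss_np([0.8], num_replicas=4)

# Multi-dimensional labels: per-pixel losses of shape (batch, H, W) are divided
# by H * W, as in the last bullet, before the global-batch averaging.
per_pixel_loss = np.full((2, 3, 3), 0.9)
per_pixel_loss = per_pixel_loss / np.prod(per_pixel_loss.shape[1:])
pixel_avg = compute_average_loss_np(per_pixel_loss, global_batch_size=2)
```

Here replica_loss is 10 / 8 = 1.25 and weighted_loss is (1 + 3) / 8 = 0.5; summing reg_per_replica over the 4 replicas recovers the original 0.8.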

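    Caution: Verify the shape of your loss. Loss functions in tf.losses/tf.keras.losses typically return the average over the last dimension of the input. The loss classes wrap these functions. Passing reduction=Reduction.NONE when creating an instance of a loss class means "no additional reduction". For categorical losses with an example input shape of [batch, W, H, n_classes], the n_classes dimension is reduced. For pointwise losses such as losses.mean_squared_error or losses.binary_crossentropy, include a dummy axis so that [batch, W, H, 1] is reduced to [batch, W, H]. Without the dummy axis, [batch, W, H] will be incorrectly reduced to [batch, W].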
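The dummy-axis pitfall can be reproduced with a NumPy stand-in for a pointwise loss. `mse_last_axis` is a hypothetical helper that, like the Keras loss functions described above, averages over the last axis only.

```python
import numpy as np

def mse_last_axis(y_true, y_pred):
    # Average the squared error over the last axis only.
    return np.mean(np.square(y_true - y_pred), axis=-1)

batch, W, H = 8, 5, 4
y_true = np.zeros((batch, W, H))
y_pred = np.ones((batch, W, H))

# Without a dummy axis, the H axis is averaged away: shape (batch, W).
without_dummy = mse_last_axis(y_true, y_pred)
# With a trailing dummy axis, [batch, W, H, 1] reduces to [batch, W, H].
with_dummy = mse_last_axis(y_true[..., None], y_pred[..., None])

print(without_dummy.shape, with_dummy.shape)  # (8, 5) (8, 5, 4)
```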

with strategy.scope():
  # Set reduction to `NONE` so you can do the reduction afterwards and divide by
  # global batch size.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions):
    per_example_loss = loss_object(labels, predictions)
    return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)


These metrics track the test loss and the training and test accuracy. You can use .result() to get the accumulated statistics at any time.

with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')


# A model, an optimizer, and a checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = create_model()

  optimizer = tf.keras.optimizers.Adam()

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss 

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)

# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  # Each replica's loss is already divided by GLOBAL_BATCH_SIZE, so a SUM
  # reduction yields the average loss over the global batch.
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

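for epoch in range(EPOCHS):
  # Training loop
  total_loss = 0.0
  num_batches = 0
  for x in train_dist_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # Test loop
  for x in test_dist_dataset:
    distributed_test_step(x)

  # Save a checkpoint every other epoch.
  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print(template.format(epoch + 1, train_loss,
                         train_accuracy.result() * 100, test_loss.result(),
                         test_accuracy.result() * 100))

  # Reset the metrics so that each epoch reports its own statistics.
  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()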

Epoch 1, Loss: 0.6526154279708862, Accuracy: 76.88500213623047, Test Loss: 0.463595449924469, Test Accuracy: 83.1300048828125
Epoch 2, Loss: 0.4034326672554016, Accuracy: 85.66166687011719, Test Loss: 0.3861437141895294, Test Accuracy: 86.69999694824219
Epoch 3, Loss: 0.3507663607597351, Accuracy: 87.48999786376953, Test Loss: 0.3586585521697998, Test Accuracy: 87.69999694824219
Epoch 4, Loss: 0.31498587131500244, Accuracy: 88.6683349609375, Test Loss: 0.34132587909698486, Test Accuracy: 87.88999938964844
Epoch 5, Loss: 0.29295358061790466, Accuracy: 89.3949966430664, Test Loss: 0.32417431473731995, Test Accuracy: 88.3499984741211
Epoch 6, Loss: 0.2780560553073883, Accuracy: 89.86499786376953, Test Loss: 0.3235797882080078, Test Accuracy: 88.6300048828125
Epoch 7, Loss: 0.26235538721084595, Accuracy: 90.52333068847656, Test Loss: 0.2898319959640503, Test Accuracy: 89.52000427246094
Epoch 8, Loss: 0.25144392251968384, Accuracy: 90.84666442871094, Test Loss: 0.295357882976532, Test Accuracy: 89.11000061035156
Epoch 9, Loss: 0.23802918195724487, Accuracy: 91.27166748046875, Test Loss: 0.2847434878349304, Test Accuracy: 89.9800033569336
Epoch 10, Loss: 0.22753538191318512, Accuracy: 91.67832946777344, Test Loss: 0.2718784213066101, Test Accuracy: 90.09000396728516



A model checkpointed with tf.distribute.Strategy can be restored with or without a strategy.

eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

@tf.function
def eval_step(images, labels):
  predictions = new_model(images, training=False)
  eval_accuracy(labels, predictions)

# Restore the latest checkpoint into a model created outside any strategy scope.
checkpoint = tf.train.Checkpoint(optimizer=new_optimizer, model=new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
  eval_step(images, labels)

print('Accuracy after restoring the saved model without strategy: {}'.format(
    eval_accuracy.result() * 100))
Accuracy after restoring the saved model without strategy: 89.9800033569336



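If you want to iterate over a given number of steps rather than through the entire dataset, you can create an iterator with the iter call and explicitly call next on it. You can choose to iterate over the dataset both inside and outside a tf.function. Here is a small snippet that uses an iterator to iterate over the dataset outside a tf.function.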

for epoch in range(EPOCHS):
  total_loss = 0.0
  num_batches = 0
  train_iter = iter(train_dist_dataset)

  # Run a fixed number of steps (10) instead of a full pass over the dataset.
  for _ in range(10):
    total_loss += distributed_train_step(next(train_iter))
    num_batches += 1
  average_train_loss = total_loss / num_batches

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, average_train_loss, train_accuracy.result() * 100))
  train_accuracy.reset_states()
Epoch 1, Loss: 0.20245365798473358, Accuracy: 92.5
Epoch 2, Loss: 0.22390851378440857, Accuracy: 91.796875
Epoch 3, Loss: 0.2280225306749344, Accuracy: 91.5234375
Epoch 4, Loss: 0.2149706780910492, Accuracy: 92.265625
Epoch 5, Loss: 0.19533059000968933, Accuracy: 92.96875
Epoch 6, Loss: 0.20433831214904785, Accuracy: 92.421875
Epoch 7, Loss: 0.20081932842731476, Accuracy: 92.5
Epoch 8, Loss: 0.22768919169902802, Accuracy: 91.484375
Epoch 9, Loss: 0.24364034831523895, Accuracy: 91.09375
Epoch 10, Loss: 0.20831342041492462, Accuracy: 92.96875

Iterating inside a tf.function

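You can also iterate over the entire input train_dist_dataset inside a tf.function using the for x in ... construct, or by creating iterators as shown above. The example below wraps one epoch of training in a function decorated with @tf.function and iterates over train_dist_dataset inside the function.

@tf.function
def distributed_train_epoch(dataset):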
  total_loss = 0.0
  num_batches = 0
  for x in dataset:
    per_replica_losses = strategy.run(train_step, args=(x,))
    total_loss += strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    num_batches += 1
  return total_loss / tf.cast(num_batches, dtype=tf.float32)

for epoch in range(EPOCHS):
  train_loss = distributed_train_epoch(train_dist_dataset)

  template = ("Epoch {}, Loss: {}, Accuracy: {}")
  print(template.format(epoch + 1, train_loss, train_accuracy.result() * 100))

  train_accuracy.reset_states()

INFO:tensorflow:Collective all_reduce tensors: 8 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1
Epoch 1, Loss: 0.21631334722042084, Accuracy: 91.98333740234375
Epoch 2, Loss: 0.20537973940372467, Accuracy: 92.4749984741211
Epoch 3, Loss: 0.19430002570152283, Accuracy: 92.9050064086914
Epoch 4, Loss: 0.18660134077072144, Accuracy: 93.16999816894531
Epoch 5, Loss: 0.1797751635313034, Accuracy: 93.44666290283203
Epoch 6, Loss: 0.1693052053451538, Accuracy: 93.79000091552734
Epoch 7, Loss: 0.16252641379833221, Accuracy: 94.12333679199219
Epoch 8, Loss: 0.15544697642326355, Accuracy: 94.31666564941406
Epoch 9, Loss: 0.149961918592453, Accuracy: 94.52833557128906
Epoch 10, Loss: 0.13953761756420135, Accuracy: 94.99166870117188



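Because of the loss scaling computation that is carried out, it is not recommended to use tf.keras.metrics.Mean to track the training loss across different replicas. For example, consider a training job with the following characteristics:

  • Two replicas
  • Two examples processed on each replica
  • Resulting loss values: [2, 3] and [4, 5] on the two replicas
  • Global batch size = 4

With loss scaling, you calculate the per-sample value of the loss on each replica by adding the loss values and then dividing by the global batch size. In this case: (2 + 3) / 4 = 1.25 and (4 + 5) / 4 = 2.25.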

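If you use tf.keras.metrics.Mean to track the loss across the two replicas, the result is different. In this example, you end up with a total of 3.50 and a count of 2, so calling result() on the metric returns total/count = 1.75. The correct global loss is the sum of the scaled per-replica values, 1.25 + 2.25 = 3.50, so a loss tracked this way is divided by an additional factor equal to the number of replicas in sync (here, 2).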
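The numbers in this example can be checked in plain Python. The sketch only reproduces the arithmetic: `scaled` holds what each replica returns after loss scaling, the sum mirrors a strategy.reduce(tf.distribute.ReduceOp.SUM, ...) over replicas, and `total / count` mimics what a Mean metric computes when it receives one value per replica.

```python
# Two replicas, two examples each, global batch size 4 (values from the text).
replica_losses = [[2.0, 3.0], [4.0, 5.0]]
GLOBAL_BATCH_SIZE = 4

# Loss scaling: each replica sums its losses and divides by the global batch size.
scaled = [sum(losses) / GLOBAL_BATCH_SIZE for losses in replica_losses]

# Correct global loss: SUM-reduce the scaled values across replicas
# (equal to the mean of all four example losses).
global_loss = sum(scaled)

# A Mean metric fed one value per replica computes total / count instead.
total, count = sum(scaled), len(scaled)
mean_metric_result = total / count

print(scaled, global_loss, mean_metric_result)  # [1.25, 2.25] 3.5 1.75
```

The ratio between the correct value and the metric's value is exactly the number of replicas.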



