トレーニングのチェックポイント

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

「TensorFlow のモデルを保存する」という言いまわしは通常、次の 2 つのいずれかを意味します。

  1. チェックポイント、または
  2. 保存されたモデル(SavedModel)

チェックポイントは、モデルで使用されるすべてのパラメータ(tf.Variableオブジェクト)の正確な値をキャプチャします。チェックポイントにはモデルで定義された計算のいかなる記述も含まれていないため、通常は、保存されたパラメータ値を使用するソースコードが利用可能な場合に限り有用です。

一方、SavedModel 形式には、パラメータ値(チェックポイント)に加え、モデルで定義された計算のシリアライズされた記述が含まれています。この形式のモデルは、モデルを作成したソースコードから独立しています。したがって、TensorFlow Serving、TensorFlow Lite、TensorFlow.js、または他のプログラミング言語のプログラム(C、C++、Java、Go、Rust、C# などの TensorFlow API)を介したデプロイに適しています。

このガイドでは、チェックポイントの書き込みと読み取りを行う API について説明します。

セットアップ

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

tf.kerasトレーニング API から保存する

tf.kerasの保存と復元に関するガイドをご覧ください。

tf.keras.Model.save_weightsで TensorFlow チェックポイントを保存します。

net.save_weights('easy_checkpoint')

チェックポイントを記述する

TensorFlow モデルの永続的な状態は、tf.Variableオブジェクトに格納されます。これらは直接作成できますが、多くの場合はtf.keras.layerstf.keras.Modelなどの高レベル API を介して作成されます。

変数を管理する最も簡単な方法は、変数を Python オブジェクトにアタッチし、それらのオブジェクトを参照することです。

tf.train.Checkpointtf.keras.layers.Layerおよびtf.keras.Modelのサブクラスは、属性に割り当てられた変数を自動的に追跡します。以下の例では、単純な線形モデルを作成し、モデルのすべての変数の値を含むチェックポイントを記述します。

Model.save_weightsで、モデルチェックポイントを簡単に保存できます。

手動チェックポイント

セットアップ

tf.train.Checkpointのすべての機能を実演するために、トイデータセットと最適化ステップを次のように定義します。

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

チェックポイントオブジェクトを作成する

チェックポイントを手動で作成するには、tf.train.Checkpointオブジェクトが必要です。チェックポイントするオブジェクトの場所は、オブジェクトの属性として設定します。

tf.train.CheckpointManagerは、複数のチェックポイントの管理にも役立ちます。

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

モデルをトレーニングおよびチェックポイントする

次のトレーニングループは、モデルとオプティマイザのインスタンスを作成し、それらをtf.train.Checkpointオブジェクトに集めます。それはデータの各バッチのループ内でトレーニングステップを呼び出し、定期的にチェックポイントをディスクに書き込みます。

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 26.91
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 20.32
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 13.76
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 7.35
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 2.48

復元してトレーニングを続ける

最初の実行後、新しいモデルとマネジャーを渡すことができますが、トレーニングをやめた所からトレーニングを再開します。

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.38
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 0.95
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.44
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.35
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.30

tf.train.CheckpointManagerオブジェクトは古いチェックポイントを削除します。上記では、最新の 3 つのチェックポイントのみを保持するように構成されています。

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

これらのパス、例えば'./tf_ckpts/ckpt-10'などは、ディスク上のファイルではなく、indexファイルのプレフィックスで、変数値を含む 1 つまたはそれ以上のデータファイルです。これらのプレフィックスは、まとめて単一のcheckpointファイル('./tf_ckpts/checkpoint')にグループ化され、CheckpointManagerがその状態を保存します。

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

読み込みの仕組み

TensorFlowは、読み込まれたオブジェクトから始めて、名前付きエッジを持つ有向グラフを走査することにより、変数をチェックポイントされた値に合わせます。エッジ名は通常、オブジェクトの属性名に由来しており、self.l1 = tf.keras.layers.Dense(5)"l1"などがその例です。tf.train.Checkpointは、tf.train.Checkpoint(step=...)"step"のように、キーワード引数名を使用します。

上記の例の依存関係グラフは次のようになります。

Visualization of the dependency graph for the example training loop

オプティマイザは赤色、通常変数は青色、オプティマイザスロット変数はオレンジ色です。他のノード、例えばtf.train.Checkpointを表すものは黒色です。

スロット変数はオプティマイザの状態の一部ですが、特定の変数のために作成されます。例えば、上記の'm'エッジはモメンタムに対応し、Adam オプティマイザが各変数のために追跡します。スロット変数は変数とオプティマイザの両方が保存される場合に限りチェックポイントに保存されるので、破線のエッジです。

tf.train.Checkpointオブジェクト上でのrestore()呼び出しは、要求された復元をキューに入れ、Checkpointオブジェクトから一致するパスがあるとすぐに変数値を復元します。例えば、ネットワークとレイヤーを介してそれへのパスを 1 つ再構築することにより、上記で定義したモデルからバイアスのみを読み込むことができます。

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # We get the restored value now
[0. 0. 0. 0. 0.]
[1.8858831 1.9214293 2.8519926 2.9979987 5.1035223]

これらの新しいオブジェクトの依存関係グラフは、上で書いたより大きなチェックポイントの遥かに小さなサブグラフです。 これには、バイアスとtf.train.Checkpointがチェックポイントに番号付けするために使用した保存カウンタのみを含みます。

Visualization of a subgraph for the bias variable

restore()は、オプションのアサーションを持つ状態オブジェクトを返します。新しいCheckpointで作成したすべてのオブジェクトが復元され、status.assert_existing_objects_matched()を渡します。

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7fdef047eba8>

チェックポイントには、層のカーネルやオプティマイザの変数など、一致しない多くのオブジェクトがあります。status.assert_consumed()は、チェックポイントとプログラムが正確に一致する場合に限り渡すため、ここでは例外をスローします。

復元遅延(Delayed restoration)

TensorFlow のLayerオブジェクトは、入力形状が利用可能な場合、最初の呼び出しまで変数の作成を遅らせる可能性があります。例えば、Denseレイヤーのカーネルの形状はレイヤーの入力形状と出力形状の両方に依存するため、コンストラクタ引数として必要な出力形状は、単独で変数を作成するために充分な情報ではありません。Layerの呼び出しは変数の値も読み取るため、復元は変数の作成とその最初の使用の間で発生する必要があります。

このイディオムをサポートするために、tf.train.Checkpointは一致する変数をまだ持たない復元をキューに入れます。

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.645149  4.866288  4.8621855 5.02036   4.9210296]]

チェックポイントを手動で検査する

tf.train.list_variablesは、チェックポイントキーとチェックポイント内の変数の形状をリスト表示します。チェックポイントキーは上で示したグラフのパスです。

tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))
[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('iterator/.ATTRIBUTES/ITERATOR_STATE', [1]),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]

リストとディクショナリを追跡する

self.l1 = tf.keras.layers.Dense(5)のような直接の属性割り当てと同様に、リストとディクショナリを属性に割り当てると、それらの内容を追跡します。

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

リストとディクショナリのラッパーオブジェクトにお気づきでしょうか。これらのラッパーは基礎的なデータ構造のチェックポイント可能なバージョンです。属性に基づく読み込みと同様に、これらのラッパーは変数の値がコンテナに追加されるとすぐにそれを復元します。

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

同じ追跡がtf.keras.Modelのサブクラスに自動的に適用され、例えばレイヤーのリストの追跡にも使用される可能性があります。

Estimator でオブジェクトベースのチェックポイントを保存する

Estimator のガイドをご覧ください。

Estimator はデフォルトで、前のセクションで説明したオブジェクトグラフではなく、変数名でチェックポイントを保存します。tf.train.Checkpointは名前ベースのチェックポイントを受け取りますが、モデルの一部を Estimator のmodel_fnの外側に移動すると変数名が変わることがあります。 オブジェクトベースのチェックポイントを保存すると、Estimator の内側でモデルをトレーニングし、外側でそれを使用することが容易になります。

import tensorflow.compat.v1 as tf_compat
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.4524446, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 36.07061.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7fdf5cdb87b8>

その後、tf.train.Checkpointは Estimator のチェックポイントをそのmodel_dirから読み込むことができます。

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

まとめ

TensorFlow オブジェクトは、それらが使用する変数の値を保存および復元するための容易で自動的な仕組みを提供します。