TensorFlow 1 のコードを TensorFlow 2 に移行する

TensorFlow.org で表示 Google Colab で実行 GitHub でソースを表示 ノートブックをダウンロード

本ドキュメントは、低レベル TensorFlow API のユーザーを対象としています。高レベル API(tf.keras)をご使用の場合は、コードを TensorFlow 2.x と完全互換にするためのアクションはほとんどまたはまったくありません。

TensorFlow 2.x で 1.X のコードを未修正で実行することは、(contrib を除き)依然として可能です。

import tensorflow.compat.v1 as tf tf.disable_v2_behavior()

しかし、これでは TensorFlow 2.0 で追加された改善の多くを活用できません。このガイドでは、コードのアップグレード、さらなる単純化、パフォーマンス向上、そしてより容易なメンテナンスについて説明します。

自動変換スクリプト

このドキュメントで説明される変更を実装する前に行うべき最初のステップは、アップグレードスクリプトを実行してみることです。

これはコードを TensorFlow 2.x にアップグレードする際の初期パスとしては十分ですが、v2 特有のコードに変換するわけではありません。コードは依然として tf.compat.v1 エンドポイントを使用して、プレースホルダー、セッション、コレクション、その他 1.x スタイルの機能へのアクセスが可能です。

トップレベルの動作の変更

tf.compat.v1.disable_v2_behavior() を使用することで TensorFlow 2.x でコードが機能する場合でも、対処すべきグローバルな動作の変更があります。主な変更点は次のとおりです。

  • Eager execution、v1.enable_eager_execution(): 暗黙的に tf.Graph を使用するコードは失敗します。このコードは必ず with tf.Graph().as_default() コンテキストでラップしてください。

  • リソース変数、v1.enable_resource_variables(): 一部のコードは、TensorFlow 参照変数によって有効化される非決定的な動作に依存する場合があります。 リソース変数は書き込み中にロックされるため、より直感的な一貫性を保証します。

    • これによりエッジケースでの動作が変わる場合があります。
    • これにより余分なコピーが作成されるため、メモリ使用量が増える可能性があります。
    • これを無効にするには、use_resource=Falsetf.Variable コンストラクタに渡します。
  • テンソルの形状、v1.enable_v2_tensorshape(): TensorFlow 2.x は、テンソルの形状の動作を簡略化されており、t.shape[0].value の代わりに t.shape[0] とすることができます。簡単な変更なので、すぐに修正しておくことをお勧めします。例については TensorShape をご覧ください。

  • 制御フロー、v1.enable_control_flow_v2(): TensorFlow 2.x 制御フローの実装が簡略化されたため、さまざまなグラフ表現を生成します。問題が生じた場合には、バグを報告してください。

TensorFlow 2.x のコードを作成する

このガイドでは、TensorFlow 1.x のコードを TensorFlow 2.x に変換するいくつかの例を確認します。これらの変更によって、コードがパフォーマンスの最適化および簡略化された API 呼び出しを活用できるようになります。

それぞれのケースのパターンは次のとおりです。

1. v1.Session.run 呼び出しを置き換える

すべての v1.Session.run 呼び出しは、Python 関数で置き換える必要があります。

  • feed_dictおよびv1.placeholderは関数の引数になります。
  • fetch は関数の戻り値になります。
  • Eager execution では、pdb などの標準的な Python ツールを使用して、変換中に簡単にデバッグできます。

次に、tf.function デコレータを追加して、グラフで効率的に実行できるようにします。 この機能についての詳細は、AutoGraph ガイドをご覧ください。

注意点:

  • v1.Session.run とは異なり、tf.function は固定のリターンシグネチャを持ち、常にすべての出力を返します。これによってパフォーマンスの問題が生じる場合は、2 つの個別の関数を作成します。

  • tf.control_dependencies または同様の演算は必要ありません。tf.function は、記述された順序で実行されたかのように動作します。たとえば、tf.Variable 割り当てと tf.assert は自動的に実行されます。

「モデルを変換する」セクションには、この変換プロセスの実際の例が含まれています。

2. Python オブジェクトを変数と損失の追跡に使用する

TensorFlow 2.x では、いかなる名前ベースの変数追跡もまったく推奨されていません。 変数の追跡には Python オブジェクトを使用します。

v1.get_variable の代わりに tf.Variable を使用してください。

すべてのv1.variable_scopeは Python オブジェクトに変換が可能です。通常は次のうちの 1 つになります。

tf.Graph.get_collection(tf.GraphKeys.VARIABLES) などの変数のリストを集める必要がある場合には、Layer および Model オブジェクトの .variables.trainable_variables 属性を使用します。

これら Layer クラスと Model クラスは、グローバルコレクションの必要性を除去した別のプロパティを幾つか実装します。.losses プロパティは、tf.GraphKeys.LOSSES コレクション使用の置き換えとなります。

詳細は Keras ガイドをご覧ください。

警告 : 多くの tf.compat.v1 シンボルはグローバルコレクションを暗黙的に使用しています。

3. トレーニングループをアップグレードする

ご利用のユースケースで動作する最高レベルの API を使用してください。独自のトレーニングループを構築するよりも tf.keras.Model.fit の選択を推奨します。

これらの高レベル関数は、独自のトレーニングループを書く場合に見落とされやすい多くの低レベル詳細を管理します。例えば、それらは自動的に正則化損失を集めて、モデルを呼び出す時にtraining=True引数を設定します。

4. データ入力パイプラインをアップグレードする

データ入力には tf.data データセットを使用してください。それらのオブジェクトは効率的で、表現力があり、TensorFlow とうまく統合します。

次のように、tf.keras.Model.fit メソッドに直接渡すことができます。

model.fit(dataset, epochs=5)

また、標準的な Python で直接にイテレートすることもできます。

for example_batch, label_batch in dataset:     break

5. compat.v1シンボルを移行する

tf.compat.v1モジュールには、元のセマンティクスを持つ完全な TensorFlow 1.x API が含まれています。

TensorFlow 2 アップグレードスクリプトは、変換が安全な場合、つまり v2 バージョンの動作が完全に同等であると判断できる場合は、シンボルを 2.0 と同等のものに変換します。(たとえば、これらは同じ関数なので、v1.arg_max の名前を tf.argmax に変更します。)

コードの一部を使用してアップグレードスクリプトを実行した後に、compat.v1 が頻出する可能性があります。 コードを調べ、それらを手動で同等の v2 のコードに変換する価値はあります。(該当するものがある場合には、ログに表示されているはずです。)

モデルを変換する

低レベル変数 & 演算子実行

低レベル API の使用例を以下に示します。

  • 変数スコープを使用して再利用を制御する。

  • v1.get_variableで変数を作成する。

  • コレクションに明示的にアクセスする。

  • 次のようなメソッドでコレクションに暗黙的にアクセスする。

  • v1.placeholder を使用してグラフ入力のセットアップをする。

  • Session.runでグラフを実行する。

  • 変数を手動で初期化する。

変換前

TensorFlow 1.x を使用したコードでは、これらのパターンは以下のように表示されます。

import tensorflow as tf
import tensorflow.compat.v1 as v1

import tensorflow_datasets as tfds
2022-08-09 06:00:37.978781: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-09 06:00:38.654323: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-09 06:00:38.654568: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-09 06:00:38.654580: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
g = v1.Graph()

with g.as_default():
  in_a = v1.placeholder(dtype=v1.float32, shape=(2))
  in_b = v1.placeholder(dtype=v1.float32, shape=(2))

  def forward(x):
    with v1.variable_scope("matmul", reuse=v1.AUTO_REUSE):
      W = v1.get_variable("W", initializer=v1.ones(shape=(2,2)),
                          regularizer=lambda x:tf.reduce_mean(x**2))
      b = v1.get_variable("b", initializer=v1.zeros(shape=(2)))
      return W * x + b

  out_a = forward(in_a)
  out_b = forward(in_b)
  reg_loss=v1.losses.get_regularization_loss(scope="matmul")

with v1.Session(graph=g) as sess:
  sess.run(v1.global_variables_initializer())
  outs = sess.run([out_a, out_b, reg_loss],
                feed_dict={in_a: [1, 0], in_b: [0, 1]})

print(outs[0])
print()
print(outs[1])
print()
print(outs[2])
[[1. 0.]
 [1. 0.]]

[[0. 1.]
 [0. 1.]]

1.0

変換後

変換されたコードでは :

  • 変数はローカル Python オブジェクトです。
  • forward関数は依然として計算を定義します。
  • Session.run呼び出しはforwardへの呼び出しに置き換えられます。
  • パフォーマンス向上のためにオプションでtf.functionデコレータを追加可能です。
  • どのグローバルコレクションも参照せず、正則化は手動で計算されます。
  • セッションやプレースホルダーはありません。
W = tf.Variable(tf.ones(shape=(2,2)), name="W")
b = tf.Variable(tf.zeros(shape=(2)), name="b")

@tf.function
def forward(x):
  return W * x + b

out_a = forward([1,0])
print(out_a)
tf.Tensor(
[[1. 0.]
 [1. 0.]], shape=(2, 2), dtype=float32)
out_b = forward([0,1])

regularizer = tf.keras.regularizers.l2(0.04)
reg_loss=regularizer(W)

tf.layersベースのモデル

v1.layersモジュールは、変数を定義および再利用するv1.variable_scopeに依存するレイヤー関数を含めるために使用されます。

変換前

def model(x, training, scope='model'):
  with v1.variable_scope(scope, reuse=v1.AUTO_REUSE):
    x = v1.layers.conv2d(x, 32, 3, activation=v1.nn.relu,
          kernel_regularizer=lambda x:0.004*tf.reduce_mean(x**2))
    x = v1.layers.max_pooling2d(x, (2, 2), 1)
    x = v1.layers.flatten(x)
    x = v1.layers.dropout(x, 0.1, training=training)
    x = v1.layers.dense(x, 64, activation=v1.nn.relu)
    x = v1.layers.batch_normalization(x, training=training)
    x = v1.layers.dense(x, 10)
    return x
train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)

print(train_out)
print()
print(test_out)
/tmpfs/tmp/ipykernel_85943/2186903737.py:3: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  x = v1.layers.conv2d(x, 32, 3, activation=v1.nn.relu,
tf.Tensor([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(1, 10), dtype=float32)

tf.Tensor(
[[-0.06471647 -0.10887557  0.2745164  -0.3928661  -0.04384201  0.14477235
  -0.60257095  0.22861081  0.02734172  0.24904892]], shape=(1, 10), dtype=float32)
/tmpfs/tmp/ipykernel_85943/2186903737.py:5: UserWarning: `tf.layers.max_pooling2d` is deprecated and will be removed in a future version. Please use `tf.keras.layers.MaxPooling2D` instead.
  x = v1.layers.max_pooling2d(x, (2, 2), 1)
/tmpfs/tmp/ipykernel_85943/2186903737.py:6: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  x = v1.layers.flatten(x)
/tmpfs/tmp/ipykernel_85943/2186903737.py:7: UserWarning: `tf.layers.dropout` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dropout` instead.
  x = v1.layers.dropout(x, 0.1, training=training)
/tmpfs/tmp/ipykernel_85943/2186903737.py:8: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  x = v1.layers.dense(x, 64, activation=v1.nn.relu)
/tmpfs/tmp/ipykernel_85943/2186903737.py:9: UserWarning: `tf.layers.batch_normalization` is deprecated and will be removed in a future version. Please use `tf.keras.layers.BatchNormalization` instead. In particular, `tf.control_dependencies(tf.GraphKeys.UPDATE_OPS)` should not be used (consult the `tf.keras.layers.BatchNormalization` documentation).
  x = v1.layers.batch_normalization(x, training=training)
/tmpfs/tmp/ipykernel_85943/2186903737.py:10: UserWarning: `tf.layers.dense` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Dense` instead.
  x = v1.layers.dense(x, 10)

変換後

ほとんどの引数はそのままです。しかし、以下の点は異なります。

  • training引数は、それが実行される時点でモデルによって各レイヤーに渡されます。
  • 元のmodel関数への最初の引数(入力 x)はなくなりました。これはオブジェクトレイヤーがモデルの呼び出しからモデルの構築を分離するためです。

また以下にも注意してください。

  • tf.contribからの初期化子の正則化子を使用している場合は、他よりも多くの引数変更があります。
  • コードはコレクションに書き込みを行わないため、v1.losses.get_regularization_lossなどの関数はそれらの値を返さなくなり、トレーニングループが壊れる可能性があります。
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.04),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))
train_out = model(train_data, training=True)
print(train_out)
tf.Tensor([[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(1, 10), dtype=float32)
test_out = model(test_data, training=False)
print(test_out)
tf.Tensor(
[[-0.06694235  0.24037899  0.04264311  0.21153995 -0.14116539 -0.02983543
  -0.21799125 -0.04154703 -0.15375452 -0.11764101]], shape=(1, 10), dtype=float32)
# Here are all the trainable variables.
len(model.trainable_variables)
8
# Here is the regularization loss.
model.losses
[<tf.Tensor: shape=(), dtype=float32, numpy=0.08283685>]

変数とv1.layersの混在

既存のコードは低レベルの TensorFlow 1.x 変数と演算子に高レベルのv1.layersが混ざっていることがよくあります。

変換前

def model(x, training, scope='model'):
  with v1.variable_scope(scope, reuse=v1.AUTO_REUSE):
    W = v1.get_variable(
      "W", dtype=v1.float32,
      initializer=v1.ones(shape=x.shape),
      regularizer=lambda x:0.004*tf.reduce_mean(x**2),
      trainable=True)
    if training:
      x = x + W
    else:
      x = x + W * 0.5
    x = v1.layers.conv2d(x, 32, 3, activation=tf.nn.relu)
    x = v1.layers.max_pooling2d(x, (2, 2), 1)
    x = v1.layers.flatten(x)
    return x

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)
/tmpfs/tmp/ipykernel_85943/1266778715.py:12: UserWarning: `tf.layers.conv2d` is deprecated and will be removed in a future version. Please Use `tf.keras.layers.Conv2D` instead.
  x = v1.layers.conv2d(x, 32, 3, activation=tf.nn.relu)
/tmpfs/tmp/ipykernel_85943/1266778715.py:13: UserWarning: `tf.layers.max_pooling2d` is deprecated and will be removed in a future version. Please use `tf.keras.layers.MaxPooling2D` instead.
  x = v1.layers.max_pooling2d(x, (2, 2), 1)
/tmpfs/tmp/ipykernel_85943/1266778715.py:14: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead.
  x = v1.layers.flatten(x)

変換後

このコードを変換するには、前の例で示したレイヤーからレイヤーへのマッピングのパターンに従います。

一般的なパターンは次の通りです。

  • __init__でレイヤーパラメータを収集する。
  • buildで変数を構築する。
  • callで計算を実行し、結果を返す。

v1.variable_scopeは事実上それ自身のレイヤーです。従ってtf.keras.layers.Layerとして書き直します。詳細はガイドをご覧ください。

# Create a custom layer for part of the model
class CustomLayer(tf.keras.layers.Layer):
  def __init__(self, *args, **kwargs):
    super(CustomLayer, self).__init__(*args, **kwargs)

  def build(self, input_shape):
    self.w = self.add_weight(
        shape=input_shape[1:],
        dtype=tf.float32,
        initializer=tf.keras.initializers.ones(),
        regularizer=tf.keras.regularizers.l2(0.02),
        trainable=True)

  # Call method will sometimes get used in graph mode,
  # training will get turned into a tensor
  @tf.function
  def call(self, inputs, training=None):
    if training:
      return inputs + self.w
    else:
      return inputs + self.w * 0.5
custom_layer = CustomLayer()
print(custom_layer([1]).numpy())
print(custom_layer([1], training=True).numpy())
[1.5]
[2.]
train_data = tf.ones(shape=(1, 28, 28, 1))
test_data = tf.ones(shape=(1, 28, 28, 1))

# Build the model including the custom layer
model = tf.keras.Sequential([
    CustomLayer(input_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, 3, activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
])

train_out = model(train_data, training=True)
test_out = model(test_data, training=False)

注意点:

  • サブクラス化された Keras モデルとレイヤーは v1 グラフ(自動制御依存性なし)と eager モードの両方で実行される必要があります。

    • call()tf.function()にラップして、AutoGraph と自動制御依存性を得るようにします。
  • training引数を受け取ってcallすることを忘れないようにしてください。

    • それはtf.Tensorである場合があります。
    • それは Python ブール型である場合があります。
  • self.add_weight()を使用して、コンストラクタまたはModel.buildでモデル変数を作成します。

    • Model.buildでは、入力形状にアクセスできるため、適合する形状で重みを作成できます。
    • tf.keras.layers.Layer.add_weightを使用すると、Keras が変数と正則化損失を追跡できるようになります。
  • オブジェクトにtf.Tensorsを保持してはいけません。

    • それらはtf.functionまたは eager コンテキスト内のいずれかで作成される可能性がありますが、それらのテンソルは異なる振る舞いをします。
    • 状態にはtf.Variableを使用してください。これは常に両方のコンテキストから使用可能です。
    • tf.Tensorsは中間値専用です。

Slim & contrib.layers に関する注意

古い TensorFlow 1.x コードの大部分は Slim ライブラリを使用しており、これはtf.contrib.layersとして TensorFlow 1.x でパッケージ化されていました。 contribモジュールに関しては、TensorFlow 2.x ではtf.compat.v1内でも、あっても利用できなくなりました。Slim を使用したコードの TensorFlow 2.x への変換は、v1.layersを使用したレポジトリの変換よりも複雑です。現実的には、まず最初に Slim コードをv1.layersに変換してから Keras に変換するほうが賢明かもしれません。

  • arg_scopesを除去します。すべての引数は明示的である必要があります。
  • それらを使用する場合、 normalizer_fnactivation_fnをそれら自身のレイヤーに分割します。
  • 分離可能な畳み込みレイヤーは 1 つまたはそれ以上の異なる Keras レイヤー(深さ的な、ポイント的な、分離可能な Keras レイヤー)にマップします。
  • Slim とv1.layersには異なる引数名とデフォルト値があります。
  • 一部の引数には異なるスケールがあります。
  • Slim 事前トレーニング済みモデルを使用する場合は、tf.keras.applicationsから Keras 事前トレーニング済みモデル、または元の Slim コードからエクスポートされた TensorFlow ハブの TensorFlow 2 SavedModel をお試しください。

一部のtf.contribレイヤーはコアの TensorFlow に移動されていない可能性がありますが、代わりに TensorFlow アドオンパッケージに移動されています。

トレーニング

tf.kerasモデルにデータを供給する方法は沢山あります。それらは Python ジェネレータと Numpy 配列を入力として受け取ります。

モデルへのデータ供給方法として推奨するのは、データ操作用の高パフォーマンスクラスのコレクションを含むtf.dataパッケージの使用です。

依然としてtf.queueを使用している場合、これらは入力パイプラインとしてではなく、データ構造としてのみサポートされます。

データセットを使用する

TensorFlow Dataset パッケージ(tfds)には、事前定義されたデータセットをtf.data.Datasetオブジェクトとして読み込むためのユーティリティが含まれています。

この例として、tfdsを使用して MNISTdataset を読み込んでみましょう。

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

次に、トレーニング用のデータを準備します。

  • 各画像をリスケールする。
  • 例の順序をシャッフルする。
  • 画像とラベルのバッチを集める。
BUFFER_SIZE = 10 # Use a much larger value for real code.
BATCH_SIZE = 64
NUM_EPOCHS = 5


def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

例を短く保つために、データセットをトリミングして 5 バッチのみを返すようにします。

train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)

STEPS_PER_EPOCH = 5

train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2022-08-09 06:00:45.726959: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Keras トレーニングループを使用する

トレーニングプロセスの低レベル制御が不要な場合は、Keras 組み込みのfitevaluatepredictメソッドの使用が推奨されます。これらのメソッドは(シーケンシャル、関数型、またはサブクラス化)実装を問わず、モデルをトレーニングするための統一インターフェースを提供します。

これらのメソッドには次のような優位点があります。

  • Numpy 配列、Python ジェネレータ、tf.data.Datasetsを受け取ります。
  • 正則化と活性化損失を自動的に適用します。
  • マルチデバイストレーニングのためにtf.distributeをサポートします。
  • 任意の callable は損失とメトリクスとしてサポートします。
  • tf.keras.callbacks.TensorBoardのようなコールバックとカスタムコールバックをサポートします。
  • 自動的に TensorFlow グラフを使用し、高性能です。

ここにDatasetを使用したモデルのトレーニング例を示します。(この機能ついての詳細はチュートリアルをご覧ください。)

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)

print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 1s 6ms/step - loss: 1.4698 - accuracy: 0.5562
Epoch 2/5
2022-08-09 06:00:46.752921: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 4ms/step - loss: 0.4207 - accuracy: 0.9094
Epoch 3/5
2022-08-09 06:00:47.058149: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 4ms/step - loss: 0.2877 - accuracy: 0.9531
Epoch 4/5
2022-08-09 06:00:47.305885: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 4ms/step - loss: 0.2084 - accuracy: 0.9781
Epoch 5/5
2022-08-09 06:00:47.605169: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 4ms/step - loss: 0.1614 - accuracy: 0.9906
2022-08-09 06:00:47.832533: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
5/5 [==============================] - 0s 3ms/step - loss: 1.5770 - accuracy: 0.5500
Loss 1.5769875049591064, Accuracy 0.550000011920929
2022-08-09 06:00:48.255285: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

ループを自分で書く

Keras モデルのトレーニングステップは動作していても、そのステップの外でより制御が必要な場合は、データ イテレーション ループでtf.keras.Model.train_on_batchメソッドの使用を検討してみてください。

tf.keras.callbacks.Callbackとして、多くのものが実装可能であることに留意してください。

このメソッドには前のセクションで言及したメソッドの優位点の多くがありますが、外側のループのユーザー制御も与えます。

tf.keras.Model.test_on_batchまたはtf.keras.Model.evaluateを使用して、トレーニング中のパフォーマンスをチェックすることも可能です。

注意: train_on_batchtest_on_batchは、デフォルトで単一バッチの損失とメトリクスを返します。reset_metrics=Falseを渡すと累積メトリックを返しますが、必ずメトリックアキュムレータを適切にリセットすることを忘れないようにしてくだい。また、AUCのような一部のメトリクスは正しく計算するためにreset_metrics=Falseが必要なことも覚えておいてください。

上のモデルのトレーニングを続けます。

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

for epoch in range(NUM_EPOCHS):
  #Reset the metric accumulators
  model.reset_metrics()

  for image_batch, label_batch in train_data:
    result = model.train_on_batch(image_batch, label_batch)
    metrics_names = model.metrics_names
    print("train: ",
          "{}: {:.3f}".format(metrics_names[0], result[0]),
          "{}: {:.3f}".format(metrics_names[1], result[1]))
  for image_batch, label_batch in test_data:
    result = model.test_on_batch(image_batch, label_batch,
                                 # return accumulated metrics
                                 reset_metrics=False)
  metrics_names = model.metrics_names
  print("\neval: ",
        "{}: {:.3f}".format(metrics_names[0], result[0]),
        "{}: {:.3f}".format(metrics_names[1], result[1]))
train:  loss: 0.156 accuracy: 1.000
train:  loss: 0.161 accuracy: 0.969
train:  loss: 0.142 accuracy: 0.984
train:  loss: 0.186 accuracy: 0.953
train:  loss: 0.232 accuracy: 0.953
2022-08-09 06:00:49.067232: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-08-09 06:00:49.369474: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
eval:  loss: 1.332 accuracy: 0.695
train:  loss: 0.084 accuracy: 1.000
train:  loss: 0.106 accuracy: 1.000
train:  loss: 0.083 accuracy: 1.000
train:  loss: 0.111 accuracy: 1.000
train:  loss: 0.104 accuracy: 1.000
2022-08-09 06:00:49.627975: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-08-09 06:00:49.831268: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
eval:  loss: 1.248 accuracy: 0.779
train:  loss: 0.074 accuracy: 1.000
train:  loss: 0.081 accuracy: 1.000
train:  loss: 0.073 accuracy: 1.000
train:  loss: 0.085 accuracy: 1.000
train:  loss: 0.070 accuracy: 1.000
2022-08-09 06:00:50.075280: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-08-09 06:00:50.262077: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
eval:  loss: 1.214 accuracy: 0.771
train:  loss: 0.060 accuracy: 1.000
train:  loss: 0.067 accuracy: 1.000
train:  loss: 0.058 accuracy: 1.000
train:  loss: 0.060 accuracy: 1.000
train:  loss: 0.061 accuracy: 1.000
2022-08-09 06:00:50.540956: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-08-09 06:00:50.718501: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
eval:  loss: 1.203 accuracy: 0.786
train:  loss: 0.057 accuracy: 1.000
train:  loss: 0.058 accuracy: 1.000
train:  loss: 0.053 accuracy: 1.000
train:  loss: 0.049 accuracy: 1.000
train:  loss: 0.055 accuracy: 1.000
2022-08-09 06:00:50.959857: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
eval:  loss: 1.197 accuracy: 0.781
2022-08-09 06:00:51.168042: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

トレーニングステップをカスタマイズする

より多くの柔軟性と制御を必要とする場合、独自のトレーニングループを実装することでそれが可能になります。以下の 3 つのステップを踏みます。

  1. Python ジェネレータかtf.data.Datasetをイテレートして例のバッチを作成します。
  2. tf.GradientTapeを使用して勾配を集めます。
  3. tf.keras.optimizersの 1 つを使用して、モデルの変数に重み更新を適用します。

留意点:

  • サブクラス化されたレイヤーとモデルのcallメソッドには、常にtraining引数を含めます。
  • training引数を確実に正しくセットしてモデルを呼び出します。
  • 使用方法によっては、モデルがデータのバッチ上で実行されるまでモデル変数は存在しないかもしれません。
  • モデルの正則化損失などを手動で処理する必要があります。

v1 と比べて簡略化されている点に注意してください :

  • 変数初期化子を実行する必要はありません。作成時に変数は初期化されます。
  • たとえtf.function演算が eager モードで振る舞う場合でも、手動の制御依存性を追加する必要はありません。
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
2022-08-09 06:00:51.964439: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 0
2022-08-09 06:00:52.570321: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 1
2022-08-09 06:00:52.900317: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 2
2022-08-09 06:00:53.103914: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Finished epoch 3
Finished epoch 4
2022-08-09 06:00:53.318230: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

新しいスタイルのメトリクスと損失

TensorFlow 2.x では、メトリクスと損失はオブジェクトです。Eager で実行的にtf.function内で動作します。

損失オブジェクトは呼び出し可能で、(y_true, y_pred) を引数として期待します。

cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815

メトリックオブジェクトには次のメソッドがあります 。

  • Metric.update_state() — 新しい観測を追加する
  • Metric.result() — 観測値が与えられたとき、メトリックの現在の結果を得る
  • Metric.reset_states() — すべての観測をクリアする

オブジェクト自体は呼び出し可能です。呼び出しはupdate_stateと同様に新しい観測の状態を更新し、メトリクスの新しい結果を返します。

メトリックの変数を手動で初期化する必要はありません。また、TensorFlow 2.x は自動制御依存性を持つため、それらについても気にする必要はありません。

次のコードは、メトリックを使用してカスタムトレーニングループ内で観測される平均損失を追跡します。

# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss=tf.math.add_n(model.losses)
    pred_loss=loss_fn(labels, predictions)
    total_loss=pred_loss + regularization_loss

  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)


for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)
  # Get the metric results
  mean_loss=loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
2022-08-09 06:00:53.929238: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  0
  loss:     0.149
  accuracy: 0.994
2022-08-09 06:00:54.166225: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  1
  loss:     0.131
  accuracy: 1.000
2022-08-09 06:00:54.481973: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  2
  loss:     0.115
  accuracy: 1.000
2022-08-09 06:00:54.774067: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Epoch:  3
  loss:     0.100
  accuracy: 1.000
Epoch:  4
  loss:     0.090
  accuracy: 1.000
2022-08-09 06:00:54.994654: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Keras メトリック名

TensorFlow 2.x では、Keras モデルはメトリクス名の処理に関してより一貫性があります。

メトリクスリストで文字列を渡すと、まさにその文字列がメトリクスのnameとして使用されます。これらの名前は
model.fitによって返される履歴オブジェクトと、keras.callbacksに渡されるログに表示されます。これはメトリクスリストで渡した文字列に設定されています。

model.compile(
    optimizer = tf.keras.optimizers.Adam(0.001),
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.1179 - acc: 0.9875 - accuracy: 0.9875 - my_accuracy: 0.9875
2022-08-09 06:00:56.126928: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])

これはmetrics=["accuracy"]を渡すとdict_keys(['loss', 'acc'])になっていた、以前のバージョンとは異なります。

Keras オプティマイザ

v1.train.AdamOptimizerv1.train.GradientDescentOptimizerなどのv1.train内のオプティマイザは、tf.keras.optimizers内に同等のものを持ちます。

v1.trainkeras.optimizersに変換する

オプティマイザを変換する際の注意事項を次に示します。

一部のtf.keras.optimizersの新しいデフォルト

警告: モデルの収束挙動に変化が見られる場合には、デフォルトの学習率を確認してください。

optimizers.SGDoptimizers.Adam、またはoptimizers.RMSpropに変更はありません。

次のデフォルトの学習率が変更されました。

TensorBoard

TensorFlow 2 には、TensorBoard で視覚化するための要約データを記述するために使用されるtf.summary API の大幅な変更が含まれています。新しいtf.summaryの概要については、TensorFlow 2 API を使用した複数のチュートリアルがあります。これには、TensorBoard TensorFlow 2 移行ガイドも含まれています。

保存と読み込み

チェックポイントの互換性

TensorFlow 2.x はオブジェクトベースのチェックポイントを使用します。

古いスタイルの名前ベースのチェックポイントは、注意を払えば依然として読み込むことができます。コード変換プロセスは変数名変更という結果になるかもしれませんが、回避方法はあります。

最も単純なアプローチは、チェックポイント内の名前と新しいモデルの名前を揃えて並べることです。

  • 変数にはすべて依然として設定が可能なname引数があります。
  • Keras モデルはまた name引数を取り、それらの変数のためのプレフィックスとして設定されます。
  • v1.name_scope関数は、変数名のプレフィックスの設定に使用できます。これはtf.variable_scopeとは大きく異なります。これは名前だけに影響するもので、変数と再利用の追跡はしません。

ご利用のユースケースで動作しない場合は、v1.train.init_from_checkpointを試してみてください。これはassignment_map引数を取り、古い名前から新しい名前へのマッピングを指定します。

注意 : 読み込みを遅延できるオブジェクトベースのチェックポイントとは異なり、名前ベースのチェックポイントは関数が呼び出される時に全ての変数が構築されていることを要求します。一部のモデルは、buildを呼び出すかデータのバッチでモデルを実行するまで変数の構築を遅延します。

TensorFlow Estimatorリポジトリには事前作成された Estimator のチェックポイントを TensorFlow 1.X から 2.0 にアップグレードするための変換ツールが含まれています。これは、同様のユースケースのツールを構築する方法の例として有用な場合があります。

保存されたモデルの互換性

保存されたモデルには、互換性に関する重要な考慮事項はありません。

  • TensorFlow 1.x saved_models は TensorFlow 2.x で動作します。
  • TensorFlow 2.x saved_models は全ての演算がサポートされていれば TensorFlow 1.x で動作します。

Graph.pb または Graph.pbtxt

未加工のGraph.pbファイルを TensorFlow 2.x にアップグレードする簡単な方法はありません。確実な方法は、ファイルを生成したコードをアップグレードすることです。

ただし、「凍結グラフ」(変数が定数に変換されたtf.Graph)がある場合、v1.wrap_functionを使用してconcrete_functionへの変換が可能です。

def wrap_frozen_graph(graph_def, inputs, outputs):
  def _imports_graph_def():
    tf.compat.v1.import_graph_def(graph_def, name="")
  wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
  import_graph = wrapped_import.graph
  return wrapped_import.prune(
      tf.nest.map_structure(import_graph.as_graph_element, inputs),
      tf.nest.map_structure(import_graph.as_graph_element, outputs))

たとえば、次のような凍結された Inception v1 グラフ(2016 年)があります。

path = tf.keras.utils.get_file(
    'inception_v1_2016_08_28_frozen.pb',
    'http://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz',
    untar=True)
Downloading data from http://storage.googleapis.com/download.tensorflow.org/models/inception_v1_2016_08_28_frozen.pb.tar.gz
24695710/24695710 [==============================] - 1s 0us/step

tf.GraphDefを読み込みます。

graph_def = tf.compat.v1.GraphDef()
loaded = graph_def.ParseFromString(open(path,'rb').read())

これをconcrete_functionにラップします。

inception_func = wrap_frozen_graph(
    graph_def, inputs='input:0',
    outputs='InceptionV1/InceptionV1/Mixed_3b/Branch_1/Conv2d_0a_1x1/Relu:0')

入力としてテンソルを渡します。

input_img = tf.ones([1,224,224,3], dtype=tf.float32)
inception_func(input_img).shape
TensorShape([1, 28, 28, 96])

Estimator

Estimator でトレーニングする

Estimator は TensorFlow 2.0 でサポートされています。

Estimator を使用する際には、TensorFlow 1.x. からのinput_fn()tf.estimator.TrainSpectf.estimator.EvalSpecを使用できます。

ここに train と evaluate specs を伴う input_fn を使用する例があります。

input_fn と train/eval specs を作成する

# Define the estimator's input_fn
def input_fn():
  datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
  mnist_train, mnist_test = datasets['train'], datasets['test']

  BUFFER_SIZE = 10000
  BATCH_SIZE = 64

  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label[..., tf.newaxis]

  train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
  return train_data.repeat()

# Define train &amp; eval specs
train_spec = tf.estimator.TrainSpec(input_fn=input_fn,
                                    max_steps=STEPS_PER_EPOCH * NUM_EPOCHS)
eval_spec = tf.estimator.EvalSpec(input_fn=input_fn,
                                  steps=STEPS_PER_EPOCH)

Keras モデル定義を使用する

TensorFlow 2.x で Estimator を構築する方法には、いくつかの違いがあります。

モデルは Keras を使用して定義することを推奨します。次にtf.keras.estimator.model_to_estimatorユーティリティを使用して、モデルを Estimator に変更します。次のコードは Estimator を作成してトレーニングする際に、このユーティリティをどのように使用するかを示します。

def make_model():
  return tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
  ])
model = make_model()

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

estimator = tf.keras.estimator.model_to_estimator(
  keras_model = model
)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpgzsw8u5z
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpgzsw8u5z
INFO:tensorflow:Using the Keras model provided.
INFO:tensorflow:Using the Keras model provided.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/layers/normalization/batch_normalization.py:514: _colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/backend.py:450: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/layers/normalization/batch_normalization.py:514: _colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpgzsw8u5z', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpgzsw8u5z', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.util has been moved to tensorflow.python.checkpoint.checkpoint. The old module will be deleted in version 2.11.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.util has been moved to tensorflow.python.checkpoint.checkpoint. The old module will be deleted in version 2.11.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmpfs/tmp/tmpgzsw8u5z/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmpfs/tmp/tmpgzsw8u5z/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmpfs/tmp/tmpgzsw8u5z/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmpfs/tmp/tmpgzsw8u5z/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 8 variables.
INFO:tensorflow:Warm-started 8 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpgzsw8u5z/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpgzsw8u5z/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.7048984, step = 0
INFO:tensorflow:loss = 2.7048984, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmpfs/tmp/tmpgzsw8u5z/model.ckpt.
INFO:tensorflow:Saving checkpoints for 25 into /tmpfs/tmp/tmpgzsw8u5z/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.checkpoint_management has been moved to tensorflow.python.checkpoint.checkpoint_management. The old module will be deleted in version 2.9.
WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.checkpoint_management has been moved to tensorflow.python.checkpoint.checkpoint_management. The old module will be deleted in version 2.9.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/engine/training_v1.py:2045: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  updates = self.state_updates
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-08-09T06:01:01
INFO:tensorflow:Starting evaluation at 2022-08-09T06:01:01
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpgzsw8u5z/model.ckpt-25
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpgzsw8u5z/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.79957s
2022-08-09 06:01:02.018865: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
INFO:tensorflow:Inference Time : 0.79957s
INFO:tensorflow:Finished evaluation at 2022-08-09-06:01:02
INFO:tensorflow:Finished evaluation at 2022-08-09-06:01:02
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.5125, global_step = 25, loss = 1.6212566
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.5125, global_step = 25, loss = 1.6212566
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmpfs/tmp/tmpgzsw8u5z/model.ckpt-25
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmpfs/tmp/tmpgzsw8u5z/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.45869094.
2022-08-09 06:01:02.086642: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
INFO:tensorflow:Loss for final step: 0.45869094.
({'accuracy': 0.5125, 'loss': 1.6212566, 'global_step': 25}, [])

注意 : Keras で重み付きメトリクスを作成し、model_to_estimatorを使用してそれらを Estimator API で重み付きメトリクスを変換することはサポートされません。それらのメトリクスは、add_metrics関数を使用して Estimator 仕様で直接作成する必要があります。

カスタム model_fn を使用する

保守する必要がある既存のカスタム Estimator model_fn を持つ場合には、model_fnを変換して Keras モデルを使用できるようにすることが可能です。

しかしながら、互換性の理由から、カスタムmodel_fnは依然として1.x スタイルのグラフモードで動作します。これは eager execution はなく自動制御依存性もないことも意味します。

注意: 長期的には、特にカスタムの model_fn を使って、tf.estimator から移行することを計画する必要があります。代替の API は tf.kerastf.distribute です。トレーニングの一部に Estimator を使用する必要がある場合は、tf.keras.estimator.model_to_estimator コンバータを使用して keras.Model から Estimator を作成する必要があります。

最小限の変更で model_fn をカスタマイズする

TensorFlow 2.0 でカスタムmodel_fnを動作させるには、既存のコードの変更を最小限に留めたい場合、optimizersmetricsなどのtf.compat.v1シンボルを使用することができます。

カスタムmodel_fnで Keras モデルを使用することは、それをカスタムトレーニングループで使用することに類似しています。

  • mode引数を基に、training段階を適切に設定します。
  • モデルのtrainable_variablesをオプティマイザに明示的に渡します。

しかし、カスタムループと比較して、重要な違いがあります。

  • Model.lossesを使用する代わりにModel.get_losses_forを使用して損失を抽出します。
  • Model.get_updates_forを使用してモデルの更新を抽出します。

注意 : 「更新」は各バッチの後にモデルに適用される必要がある変更です。例えば、layers.BatchNormalizationレイヤーの平均と分散の移動平均などです。

次のコードはカスタムmodel_fnから Estimator を作成し、これらの懸念事項をすべて示しています。

def my_model_fn(features, labels, mode):
  model = make_model()

  optimizer = tf.compat.v1.train.AdamOptimizer()
  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

  training = (mode == tf.estimator.ModeKeys.TRAIN)
  predictions = model(features, training=training)

  if mode == tf.estimator.ModeKeys.PREDICT:
    return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  total_loss=loss_fn(labels, predictions) + tf.math.add_n(reg_losses)

  accuracy = tf.compat.v1.metrics.accuracy(labels=labels,
                                           predictions=tf.math.argmax(predictions, axis=1),
                                           name='acc_op')

  update_ops = model.get_updates_for(None) + model.get_updates_for(features)
  minimize_op = optimizer.minimize(
      total_loss,
      var_list=model.trainable_variables,
      global_step=tf.compat.v1.train.get_or_create_global_step())
  train_op = tf.group(minimize_op, update_ops)

  return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=total_loss,
    train_op=train_op, eval_metric_ops={'accuracy': accuracy})

# Create the Estimator &amp; Train
estimator = tf.estimator.Estimator(model_fn=my_model_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpgdgezc32
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpgdgezc32
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpgdgezc32', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpgdgezc32', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpgdgezc32/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpgdgezc32/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.7353652, step = 0
INFO:tensorflow:loss = 2.7353652, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmpfs/tmp/tmpgdgezc32/model.ckpt.
INFO:tensorflow:Saving checkpoints for 25 into /tmpfs/tmp/tmpgdgezc32/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-08-09T06:01:05
INFO:tensorflow:Starting evaluation at 2022-08-09T06:01:05
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpgdgezc32/model.ckpt-25
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpgdgezc32/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.71139s
2022-08-09 06:01:06.467218: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
INFO:tensorflow:Inference Time : 0.71139s
INFO:tensorflow:Finished evaluation at 2022-08-09-06:01:06
INFO:tensorflow:Finished evaluation at 2022-08-09-06:01:06
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.66875, global_step = 25, loss = 1.5850646
INFO:tensorflow:Saving dict for global step 25: accuracy = 0.66875, global_step = 25, loss = 1.5850646
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmpfs/tmp/tmpgdgezc32/model.ckpt-25
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmpfs/tmp/tmpgdgezc32/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.40536988.
2022-08-09 06:01:06.568833: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
INFO:tensorflow:Loss for final step: 0.40536988.
({'accuracy': 0.66875, 'loss': 1.5850646, 'global_step': 25}, [])

TensorFlow 2.x シンボルでmodel_fnをカスタマイズする

TensorFlow 1.x シンボルをすべて削除し、カスタムmodel_fn をネイティブの TensorFlow 2.x にアップグレードする場合は、オプティマイザとメトリクスをtf.keras.optimizerstf.keras.metricsにアップグレードする必要があります。

カスタムmodel_fnでは、上記の変更に加えて、さらにアップグレードを行う必要があります。

  • v1.train.Optimizer の代わりに tf.keras.optimizers を使用します。
  • 損失が呼び出し可能(関数など)な場合は、Optimizer.minimize()を使用してtrain_op/minimize_opを取得します。
  • train_op/minimize_opを計算するには、
    • 損失がスカラー損失Tensor(呼び出し不可)の場合は、Optimizer.get_updates()を使用します。返されるリストの最初の要素は目的とするtrain_op/minimize_opです。
    • 損失が呼び出し可能(関数など)な場合は、Optimizer.minimize()を使用してtrain_op/minimize_opを取得します。
  • 評価にはtf.compat.v1.metricsの代わりにtf.keras.metricsを使用します。

上記のmy_model_fnの例では、2.0 シンボルの移行されたコードは次のように表示されます。

def my_model_fn(features, labels, mode):
  model = make_model()

  training = (mode == tf.estimator.ModeKeys.TRAIN)
  loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
  predictions = model(features, training=training)

  # Get both the unconditional losses (the None part)
  # and the input-conditional losses (the features part).
  reg_losses = model.get_losses_for(None) + model.get_losses_for(features)
  total_loss=loss_obj(labels, predictions) + tf.math.add_n(reg_losses)

  # Upgrade to tf.keras.metrics.
  accuracy_obj = tf.keras.metrics.Accuracy(name='acc_obj')
  accuracy = accuracy_obj.update_state(
      y_true=labels, y_pred=tf.math.argmax(predictions, axis=1))

  train_op = None
  if training:
    # Upgrade to tf.keras.optimizers.
    optimizer = tf.keras.optimizers.Adam()
    # Manually assign tf.compat.v1.global_step variable to optimizer.iterations
    # to make tf.compat.v1.train.global_step increased correctly.
    # This assignment is a must for any `tf.train.SessionRunHook` specified in
    # estimator, as SessionRunHooks rely on global step.
    optimizer.iterations = tf.compat.v1.train.get_or_create_global_step()
    # Get both the unconditional updates (the None part)
    # and the input-conditional updates (the features part).
    update_ops = model.get_updates_for(None) + model.get_updates_for(features)
    # Compute the minimize_op.
    minimize_op = optimizer.get_updates(
        total_loss,
        model.trainable_variables)[0]
    train_op = tf.group(minimize_op, *update_ops)

  return tf.estimator.EstimatorSpec(
    mode=mode,
    predictions=predictions,
    loss=total_loss,
    train_op=train_op,
    eval_metric_ops={'Accuracy': accuracy_obj})

# Create the Estimator &amp; Train.
estimator = tf.estimator.Estimator(model_fn=my_model_fn)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp7ua57qc2
WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmp7ua57qc2
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp7ua57qc2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp7ua57qc2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Running training and evaluation locally (non-distributed).
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Start train and evaluate loop. The evaluate will happen after every checkpoint. Checkpoint frequency is determined based on RunConfig arguments: save_checkpoints_steps None or save_checkpoints_secs 600.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp7ua57qc2/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp7ua57qc2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.7497802, step = 0
INFO:tensorflow:loss = 2.7497802, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 25...
INFO:tensorflow:Saving checkpoints for 25 into /tmpfs/tmp/tmp7ua57qc2/model.ckpt.
INFO:tensorflow:Saving checkpoints for 25 into /tmpfs/tmp/tmp7ua57qc2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 25...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2022-08-09T06:01:09
INFO:tensorflow:Starting evaluation at 2022-08-09T06:01:09
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp7ua57qc2/model.ckpt-25
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp7ua57qc2/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [1/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [2/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [3/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [4/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Evaluation [5/5]
INFO:tensorflow:Inference Time : 0.92612s
2022-08-09 06:01:10.464838: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
INFO:tensorflow:Inference Time : 0.92612s
INFO:tensorflow:Finished evaluation at 2022-08-09-06:01:10
INFO:tensorflow:Finished evaluation at 2022-08-09-06:01:10
INFO:tensorflow:Saving dict for global step 25: Accuracy = 0.61875, global_step = 25, loss = 1.6137655
INFO:tensorflow:Saving dict for global step 25: Accuracy = 0.61875, global_step = 25, loss = 1.6137655
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmpfs/tmp/tmp7ua57qc2/model.ckpt-25
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmpfs/tmp/tmp7ua57qc2/model.ckpt-25
INFO:tensorflow:Loss for final step: 0.5460182.
2022-08-09 06:01:10.544548: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
INFO:tensorflow:Loss for final step: 0.5460182.
({'Accuracy': 0.61875, 'loss': 1.6137655, 'global_step': 25}, [])

事前作成された Estimator

tf.estimator.DNN*tf.estimator.Linear*tf.estimator.DNNLinearCombined*のファミリーに含まれる事前作成された Estimator は、依然として TensorFlow 2.0 API でもサポートされていますが、一部の引数が変更されています。

  1. input_layer_partitioner: v2 で削除されました。
  2. loss_reduction: tf.compat.v1.losses.Reductionの代わりにtf.keras.losses.Reductionに更新されました。デフォルト値もtf.compat.v1.losses.Reduction.SUMからtf.keras.losses.Reduction.SUM_OVER_BATCH_SIZEに変更されています。
  3. optimizerdnn_optimizerlinear_optimizer: これらの引数はtf.compat.v1.train.Optimizerの代わりにtf.keras.optimizersに更新されています。

上記の変更を移行するには :

  1. TensorFlow 2.x では配布戦略が自動的に処理するため、input_layer_partitionerの移行は必要ありません。
  2. loss_reductionについてはtf.keras.losses.Reductionでサポートされるオプションを確認してください。
  3. optimizer 引数の場合:
    • 1) optimizerdnn_optimizer、または linear_optimizer 引数を渡さない場合、または 2) optimizer 引数を string としてコードに指定しない場合、デフォルトで tf.keras.optimizers が使用されるため、何も変更する必要はありません。
    • optimizer引数については、optimizerdnn_optimizerlinear_optimizer引数を渡さない場合、またはoptimizer引数をコード内の内のstringとして指定する場合は、何も変更する必要はありません。デフォルトでtf.keras.optimizersを使用します。それ以外の場合は、tf.compat.v1.train.Optimizerから対応するtf.keras.optimizersに更新する必要があります。

チェックポイントコンバータ

tf.keras.optimizersは異なる変数セットを生成してチェックポイントに保存するするため、keras.optimizersへの移行は TensorFlow 1.x を使用して保存されたチェックポイントを壊してしまいます。TensorFlow 2.x への移行後に古いチェックポイントを再利用できるようにするには、チェックポイントコンバータツールをお試しください。

 curl -O https://raw.githubusercontent.com/tensorflow/estimator/master/tensorflow_estimator/python/estimator/tools/checkpoint_converter.py
% Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 14924  100 14924    0     0   108k      0 --:--:-- --:--:-- --:--:--  108k

ツールにはヘルプが組み込まれています。

 python checkpoint_converter.py -h
2022-08-09 06:01:11.531847: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-08-09 06:01:12.200568: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-09 06:01:12.200795: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvrtc.so.11.1: cannot open shared object file: No such file or directory
2022-08-09 06:01:12.200817: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
usage: checkpoint_converter.py [-h]
                               {dnn,linear,combined} source_checkpoint
                               source_graph target_checkpoint

positional arguments:
  {dnn,linear,combined}
                        The type of estimator to be converted. So far, the
                        checkpoint converter only supports Canned Estimator.
                        So the allowed types include linear, dnn and combined.
  source_checkpoint     Path to source checkpoint file to be read in.
  source_graph          Path to source graph file to be read in.
  target_checkpoint     Path to checkpoint file to be written out.

optional arguments:
  -h, --help            show this help message and exit

TensorShape

このクラスはtf.compat.v1.Dimensionオブジェクトの代わりにintを保持することにより単純化されました。従って、.value()を呼び出してintを取得する必要はありません。

個々のtf.compat.v1.Dimensionオブジェクトは依然としてtf.TensorShape.dimsからアクセス可能です。

以下に TensorFlow 1.x と TensorFlow 2.x 間の違いを示します。

# Create a shape and choose an index
i = 0
shape = tf.TensorShape([16, None, 256])
shape
TensorShape([16, None, 256])

TensorFlow 1.x で次を使っていた場合:

value = shape[i].value

Then do this in TensorFlow 2.x:

value = shape[i]
value
16

TensorFlow 1.x で次を使っていた場合:

for dim in shape:     value = dim.value     print(value)

TensorFlow 2.0 では次のようにします:

for value in shape:
  print(value)
16
None
256

TensorFlow 1.x で次を使っていた場合(またはその他の次元のメソッドを使用していた場合):

dim = shape[i] dim.assert_is_compatible_with(other_dim)

TensorFlow 2.0 では次のようにします:

other_dim = 16
Dimension = tf.compat.v1.Dimension

if shape.rank is None:
  dim = Dimension(None)
else:
  dim = shape.dims[i]
dim.is_compatible_with(other_dim) # or any other dimension method
True
shape = tf.TensorShape(None)

if shape:
  dim = shape.dims[i]
  dim.is_compatible_with(other_dim) # or any other dimension method

tf.TensorShape のブール型の値は、階数がわかっている場合は Trueで、そうでない場合はFalseです。

print(bool(tf.TensorShape([])))      # Scalar
print(bool(tf.TensorShape([0])))     # 0-length vector
print(bool(tf.TensorShape([1])))     # 1-length vector
print(bool(tf.TensorShape([None])))  # Unknown-length vector
print(bool(tf.TensorShape([1, 10, 100])))       # 3D tensor
print(bool(tf.TensorShape([None, None, None]))) # 3D tensor with no known dimensions
print()
print(bool(tf.TensorShape(None)))  # A tensor with unknown rank.
True
True
True
True
True
True

False

その他の変更点

  • tf.colocate_withを削除する : TensorFlow のデバイス配置アルゴリズムが大幅に改善されたため、これはもう必要ありません。削除したことによってパフォーマンスが低下する場合には、バグを報告してください

  • v1.ConfigProtoの使用をtf.configの同等の関数に置き換える。

まとめ

全体のプロセスは次のとおりです。

  1. アップグレードスクリプトを実行する。
  2. contrib シンボルを除去する。
  3. モデルをオブジェクト指向スタイル (Keras) に切り替える。
  4. 可能なところでは tf.kerasまたはtf.estimatorトレーニングと評価ループを使用する。
  5. そうでない場合はカスタムループを使用してよいが、セッションとコレクションを回避すること。

コードを慣用的な TensorFlow 2.0 に変換するには少し作業を要しますが、変更するごとに次のような結果が得られます。

  • コード行が減少する。
  • 明瞭さと簡略性が向上する。
  • デバッグが容易になる。