TensorFlow.org で表示 | Google Colab で実行 | GitHub でソースを表示 | ノートブックをダウンロード |
TensorFlow コードを TF1.x から TF2 に移行する場合、移行したコードが TF2 でも TF1.x と同じように動作することを確認することを推薦します。
このガイドでは、tf.keras.layers.Layer
メソッドに適用された tf.compat.v1.keras.utils.track_tf1_style_variables
モデリング shim を使用した移行コードの例について説明します。TF2 モデリング shim の詳細については、モデルマッピングガイドを参照してください。
このガイドでは、次の目的で使用できるアプローチについて詳しく説明します。
- 移行されたコードを使用してトレーニングモデルから得られた結果の正確性を検証する
- TensorFlow バージョン間でコードの数値的等価性を検証する
セットアップ
pip uninstall -y -q tensorflow
# Install tf-nightly as the DeterministicRandomTestTool is available only in
# Tensorflow 2.8
pip install -q tf-nightly
pip install -q tf_slim
import tensorflow as tf
import tensorflow.compat.v1 as v1
import numpy as np
import tf_slim as slim
import sys
from contextlib import contextmanager
2022-12-14 22:00:19.013155: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay
!git clone --depth=1 https://github.com/tensorflow/models.git
import models.research.slim.nets.inception_resnet_v2 as inception
Cloning into 'models'... remote: Enumerating objects: 3590, done. remote: Counting objects: 100% (3590/3590), done. remote: Compressing objects: 100% (3006/3006), done. remote: Total 3590 (delta 942), reused 1502 (delta 530), pack-reused 0 Receiving objects: 100% (3590/3590), 47.08 MiB | 25.56 MiB/s, done. Resolving deltas: 100% (942/942), done.
重要なフォワードパスコードのチャンクを shim に入れる場合は、TF1.x と同じように動作していることを確認する必要があります。たとえば、TF-Slim Inception-Resnet-v2 モデル全体を次のように shim に入れることを検討してください。
# TF1 Inception resnet v2 forward pass based on slim layers
def inception_resnet_v2(inputs, num_classes, is_training):
with slim.arg_scope(
inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):
return inception.inception_resnet_v2(inputs, num_classes, is_training=is_training)
class InceptionResnetV2(tf.keras.layers.Layer):
"""Slim InceptionResnetV2 forward pass as a Keras layer"""
def __init__(self, num_classes, **kwargs):
super().__init__(**kwargs)
self.num_classes = num_classes
@tf.compat.v1.keras.utils.track_tf1_style_variables
def call(self, inputs, training=None):
is_training = training or False
# Slim does not accept `None` as a value for is_training,
# Keras will still pass `None` to layers to construct functional models
# without forcing the layer to always be in training or in inference.
# However, `None` is generally considered to run layers in inference.
with slim.arg_scope(
inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):
return inception.inception_resnet_v2(
inputs, self.num_classes, is_training=is_training)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_143457/2131234657.py:8: The name tf.keras.utils.track_tf1_style_variables is deprecated. Please use tf.compat.v1.keras.utils.track_tf1_style_variables instead.
ここでは、このレイヤーは実際にはそのまますぐに完全に機能します(正確な正則化損失トラッキングを備えています)。
ただし、これは当たり前のことではありません。以下のステップに従って、実際に TF1.x と同じように動作していることを確認し、数値的に完全に等価であることを確認します。これらのステップは、フォワードパスのどの部分が TF1.x からの分岐を引き起こしているかを三角測量するのにも役立ちます(モデルの別の部分ではなく、モデルのフォワードパスで分岐が発生しているかどうかを特定します)。
ステップ 1: 変数が 1 回だけ作成されることを確認する
最初に、各呼び出しで変数が再利用され、毎回新しい変数が誤って作成されて使用されないようにモデルが正しく構築されていることを検証する必要があります。たとえば、モデルが新しい Keras レイヤーを作成したり、各フォワードパス呼び出しで tf.Variable
を呼び出す場合、変数のキャプチャに失敗し、毎回新しい変数を作成する可能性が高くなります。
以下は、モデルが新しい変数を作成している場合に、そのことを検出し、モデルのどの部分がそれを行っているかをデバッグするために使用できる 2 つのコンテキストマネージャースコープです。
@contextmanager
def assert_no_variable_creations():
"""Assert no variables are created in this context manager scope."""
def invalid_variable_creator(next_creator, **kwargs):
raise ValueError("Attempted to create a new variable instead of reusing an existing one. Args: {}".format(kwargs))
with tf.variable_creator_scope(invalid_variable_creator):
yield
@contextmanager
def catch_and_raise_created_variables():
"""Raise all variables created within this context manager scope (if any)."""
created_vars = []
def variable_catcher(next_creator, **kwargs):
var = next_creator(**kwargs)
created_vars.append(var)
return var
with tf.variable_creator_scope(variable_catcher):
yield
if created_vars:
raise ValueError("Created vars:", created_vars)
スコープ内で変数を作成しようとすると、最初のスコープ(assert_no_variable_creations()
)は、すぐにエラーを発生します。これにより、スタックトレースを調べて(対話型デバッグを使用して)、既存の変数を再利用する代わりに、変数を作成したコード行を正確に把握できます。
2 番目のスコープ(catch_and_raise_created_variables()
)は、変数が作成された場合、スコープの最後で例外を発生させます。この例外には、スコープで作成されたすべての変数のリストが含まれます。これは、一般的なパターンを見つけることができる場合に、モデルが作成しているすべての重みのセットが何であるかを把握するのに役立ちます。ただし、これらの変数が作成された正確なコード行を特定するにはあまり役に立ちません。
以下の両方のスコープを使用して、shim ベースの InceptionResnetV2 レイヤーが最初の呼び出し後に新しい変数を作成せずに再利用していることを確認します。
model = InceptionResnetV2(1000)
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
# Create all weights on the first call
model(inputs)
# Verify that no new weights are created in followup calls
with assert_no_variable_creations():
model(inputs)
with catch_and_raise_created_variables():
model(inputs)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py:2212: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead. warnings.warn('`layer.apply` is deprecated and ' /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/legacy_tf_layers/core.py:332: UserWarning: `tf.layers.flatten` is deprecated and will be removed in a future version. Please use `tf.keras.layers.Flatten` instead. warnings.warn('`tf.layers.flatten` is deprecated and '
以下の例では、既存の重みを再利用する代わりに、毎回誤って新しい重みを作成するレイヤーで、これらのデコレータがどのように機能するかを観察できます。
class BrokenScalingLayer(tf.keras.layers.Layer):
"""Scaling layer that incorrectly creates new weights each time:"""
@tf.compat.v1.keras.utils.track_tf1_style_variables
def call(self, inputs):
var = tf.Variable(initial_value=2.0)
bias = tf.Variable(initial_value=2.0, name='bias')
return inputs * var + bias
model = BrokenScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)
try:
with assert_no_variable_creations():
model(inputs)
except ValueError as err:
import traceback
traceback.print_exc()
Traceback (most recent call last): File "/tmpfs/tmp/ipykernel_143457/1128777590.py", line 7, in <module> model(inputs) File "/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/utils/traceback_utils.py", line 70, in error_handler raise e.with_traceback(filtered_tb) from None File "/tmpfs/tmp/ipykernel_143457/3224979076.py", line 6, in call var = tf.Variable(initial_value=2.0) File "/tmpfs/tmp/ipykernel_143457/1829430118.py", line 5, in invalid_variable_creator raise ValueError("Attempted to create a new variable instead of reusing an existing one. Args: {}".format(kwargs)) ValueError: Exception encountered when calling layer 'broken_scaling_layer' (type BrokenScalingLayer). Attempted to create a new variable instead of reusing an existing one. Args: {'initial_value': 2.0, 'trainable': None, 'validate_shape': True, 'caching_device': None, 'name': None, 'variable_def': None, 'dtype': None, 'import_scope': None, 'constraint': None, 'synchronization': <VariableSynchronization.AUTO: 0>, 'aggregation': <VariableAggregation.NONE: 0>, 'shape': None, 'experimental_enable_variable_lifting': None} Call arguments received by layer 'broken_scaling_layer' (type BrokenScalingLayer): • inputs=tf.Tensor(shape=(1, 299, 299, 3), dtype=float32)
model = BrokenScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)
try:
with catch_and_raise_created_variables():
model(inputs)
except ValueError as err:
print(err)
('Created vars:', [<tf.Variable 'broken_scaling_layer_1/Variable:0' shape=() dtype=float32, numpy=2.0>, <tf.Variable 'broken_scaling_layer_1/bias:0' shape=() dtype=float32, numpy=2.0>])
重みを一度だけ作成し、毎回再利用するようにすることで、レイヤーを修正できます。
class FixedScalingLayer(tf.keras.layers.Layer):
"""Scaling layer that incorrectly creates new weights each time:"""
def __init__(self):
super().__init__()
self.var = None
self.bias = None
@tf.compat.v1.keras.utils.track_tf1_style_variables
def call(self, inputs):
if self.var is None:
self.var = tf.Variable(initial_value=2.0)
self.bias = tf.Variable(initial_value=2.0, name='bias')
return inputs * self.var + self.bias
model = FixedScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)
with assert_no_variable_creations():
model(inputs)
with catch_and_raise_created_variables():
model(inputs)
トラブルシューティング
以下は、モデルが既存の重みを再利用せずに誤って新しい重みを作成してしまう一般的な理由です。
- 既に作成された
tf.Variables
を再利用せずに、明示的なtf.Variable
呼び出しを使用してしまう場合、最初に作成されていないかどうかを確認してから、既存のものを再利用します。 - (
tf.compat.v1.layers
とは対照的に)毎回フォワードパスで Keras レイヤーまたはモデルを直接作成してしまう場合、最初に作成されていないかどうかを確認して、既存のものを再利用します。 tf.compat.v1.layers
の上に構築されていて、すべてのcompat.v1.layers
に明示的な名前を割り当てたり、名前付きvariable_scope
内でcompat.v1 .layer
の使用をラップできず、自動生成されたレイヤー名が各モデル呼び出しでインクリメントされてしまう場合、tf.compat.v1.layers
の使用をすべてラップする shim でデコレートされたメソッド内に名前付きのtf.compat.v1.variable_scope
を配置します。
ステップ 2: 変数の数、名前、形状が一致していることを確認する
2 番目のステップは、TF2 で実行されているレイヤーが対応するコードが TF1.x と同じ形状で同じ数の重みを作成することを確認することです。
以下に示すように、これらが一致することを確認するために手動での確認と、単体テストでのプログラムによる確認を組み合わせることができます。
# Build the forward pass inside a TF1.x graph, and
# get the counts, shapes, and names of the variables
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)
tf1_variable_names_and_shapes = {
var.name: (var.trainable, var.shape) for var in tf.compat.v1.global_variables()}
num_tf1_variables = len(tf.compat.v1.global_variables())
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer_v1.py:1694: UserWarning: `layer.apply` is deprecated and will be removed in a future version. Please use `layer.__call__` method instead. warnings.warn('`layer.apply` is deprecated and '
次に、TF2 の shim によりラップされたレイヤーに対して同じことを実行します。重みを取得する前に、モデルも複数回呼び出されることに注意してください。これは、変数の再利用を効果的にテストするために行われます。
height, width = 299, 299
num_classes = 1000
model = InceptionResnetV2(num_classes)
# The weights will not be created until you call the model
inputs = tf.ones( (1, height, width, 3))
# Call the model multiple times before checking the weights, to verify variables
# get reused rather than accidentally creating additional variables
out, endpoints = model(inputs, training=False)
out, endpoints = model(inputs, training=False)
# Grab the name: shape mapping and the total number of variables separately,
# because in TF2 variables can be created with the same name
num_tf2_variables = len(model.variables)
tf2_variable_names_and_shapes = {
var.name: (var.trainable, var.shape) for var in model.variables}
# Verify that the variable counts, names, and shapes all match:
assert num_tf1_variables == num_tf2_variables
assert tf1_variable_names_and_shapes == tf2_variable_names_and_shapes
Shim ベースの InceptionResnetV2 レイヤーは、このテストに合格しています。ただし、一致しない場合は、差分(テキストまたはその他)を実行して、差分がどこにあるかを確認できます。
これにより、モデルのどの部分が期待どおりに動作していないかが分かります。Eager execution では、pdb、インタラクティブなデバッグ、およびブレークポイントを使用して、疑わしいと思われるモデルの部分を掘り下げ、問題が何なのかをより深くデバッグできます。
トラブルシューティング
明示的な
tf.Variable
呼び出しと Keras レイヤー/モデルによって直接作成された変数の名前に細心の注意を払ってください。他のすべてが正常に機能している場合でもそれらの変数名生成セマンティクスは、TF1.x Graph と Eager execution およびtf.function
などの TF2 関数との間でわずかに異なる可能性があるためです。このような場合は、わずかに異なる命名セマンティクスを考慮してテストを調整してください。TF1.x の変数コレクションによってキャプチャされた場合でも、
tf.Variable
、tf.keras.layers.Layer
、またはtf.keras.Model
がトレーニングループのフォワードパスが TF2 変数リストにない場合があります。これを修正するには、フォワードパスが作成する変数/レイヤー/モデルをモデルのインスタンス属性に割り当てます。詳細については、こちらを参照してください。
ステップ 3: すべての変数をリセットし、ランダム性をすべて無効にして数値の等価性を確認する
次のステップでは、(推論中などに)乱数生成が含まれないようにモデルを修正するときに、実際の出力と正則化損失トラッキングの両方の数値的等価性を検証します。
正確な方法は、特定のモデルに依存する場合がありますが、ほとんどのモデル(このモデルなど)では、次の方法でこれを行うことができます。
- 重みをランダム性なしで同じ値に初期化します。そのためには、作成後に固定値にリセットします。
- モデルを推論モードで実行して、ランダム性の原因となる可能性のあるドロップアウトレイヤーがトリガーされないようにします。
次のコードは、この方法で TF1.x と TF2 の結果を比較する方法を示しています。
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)
# Rather than running the global variable initializers,
# reset all variables to a constant value
var_reset = tf.group([var.assign(tf.ones_like(var) * 0.001) for var in tf.compat.v1.global_variables()])
sess.run(var_reset)
# Grab the outputs & regularization loss
reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
tf1_output = sess.run(out)
print("Regularization loss:", tf1_regularization_loss)
tf1_output[0][:5]
Regularization loss: 0.001182976 array([0.00299837, 0.00299837, 0.00299837, 0.00299837, 0.00299837], dtype=float32)
TF2 の結果を取得します。
height, width = 299, 299
num_classes = 1000
model = InceptionResnetV2(num_classes)
inputs = tf.ones((1, height, width, 3))
# Call the model once to create the weights
out, endpoints = model(inputs, training=False)
# Reset all variables to the same fixed value as above, with no randomness
for var in model.variables:
var.assign(tf.ones_like(var) * 0.001)
tf2_output, endpoints = model(inputs, training=False)
# Get the regularization loss
tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
tf2_output[0][:5]
Regularization loss: tf.Tensor(0.0011829757, shape=(), dtype=float32) <tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.00299837, 0.00299837, 0.00299837, 0.00299837, 0.00299837], dtype=float32)>
# Create a dict of tolerance values
tol_dict={'rtol':1e-06, 'atol':1e-05}
# Verify that the regularization loss and output both match
# when we fix the weights and avoid randomness by running inference:
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
ランダム性のソースを削除すると、TF1.x と TF2 の間で数値が一致し、TF2 互換の InceptionResnetV2
レイヤーがテストに合格します。
独自のモデルで結果が分岐しているのを観察した場合は、出力または pdb と対話型デバッグを使用して、結果が分岐し始める場所と理由を特定できます。Eager execution を使用すると、これが大幅に容易になります。また、アブレーションアプローチを使用して、モデルのごく一部のみを固定中間入力で実行し、分岐が発生する場所を分離することもできます。
便利なことに、多くのスリムネット(およびその他のモデル)は、プローブ可能な中間エンドポイントも公開しています。
ステップ 4: 乱数生成を調整し、トレーニングと推論における数値等価性をチェックする
最後のステップは、TF2 モデルが TF1.x モデルと数値的に一致することを確認することです。これは、変数の初期化およびフォワードパス自体(フォワードパス中のドロップアウトレイヤーなど)での乱数生成を考慮する場合でも同様です。
これを行うには、以下のテストツールを使用して、乱数生成のセマンティクスを TF1.x Graph/Session と Eager execution の間で一致させます。
以前の TF1 Graph/Session と TF2 Eager execution は、異なるステートフルな乱数生成セマンティクスを使用します
tf.compat.v1.Session
で、シードが指定されていない場合、乱数の生成は、ランダムな演算が追加された時点で Graph にある演算の数と、その Graph の実行回数に依存します。Eager execution では、ステートフルな乱数の生成は、グローバルシード、演算のランダムシード、および指定されたランダムシードを使用した演算が実行される回数に依存します。詳細については、tf.random.set_seed
を参照してください。
次の v1.keras.utils.DeterministicRandomTestTool
クラスは、コンテキストマネージャ scope()
を提供し、 TF1 Graphs/Session と Eager execution の両方でステートフルなランダム演算が同じシードを使用できるようになります。
このツールには、次の 2 つのテストモードがあります。
constant
は、呼び出された回数に関係なく、1 つの演算ごとに同じシードを使用します。num_random_ops
は、以前に観測されたステートフルなランダム演算の数を演算シードとして使用します。
これは、変数の作成と初期化に使用されるステートフルなランダム演算と、計算で使用されるステートフルなランダム演算(ドロップアウトレイヤーなど)の両方に適用されます。
このツールを使用して、Session と Eager execution の間でステートフルな乱数生成を一致させる方法を示すために、3 つのランダムテンソルを生成します。
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
a = tf.random.uniform(shape=(3,1))
a = a * 3
b = tf.random.uniform(shape=(3,3))
b = b * 3
c = tf.random.uniform(shape=(3,3))
c = c * 3
graph_a, graph_b, graph_c = sess.run([a, b, c])
graph_a, graph_b, graph_c
(array([[2.5063772], [2.7488918], [1.4839486]], dtype=float32), array([[2.5063772, 2.7488918, 1.4839486], [1.5633398, 2.1358476, 1.3693532], [0.3598416, 1.8287641, 2.5314465]], dtype=float32), array([[2.5063772, 2.7488918, 1.4839486], [1.5633398, 2.1358476, 1.3693532], [0.3598416, 1.8287641, 2.5314465]], dtype=float32))
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
a = tf.random.uniform(shape=(3,1))
a = a * 3
b = tf.random.uniform(shape=(3,3))
b = b * 3
c = tf.random.uniform(shape=(3,3))
c = c * 3
a, b, c
(<tf.Tensor: shape=(3, 1), dtype=float32, numpy= array([[2.5063772], [2.7488918], [1.4839486]], dtype=float32)>, <tf.Tensor: shape=(3, 3), dtype=float32, numpy= array([[2.5063772, 2.7488918, 1.4839486], [1.5633398, 2.1358476, 1.3693532], [0.3598416, 1.8287641, 2.5314465]], dtype=float32)>, <tf.Tensor: shape=(3, 3), dtype=float32, numpy= array([[2.5063772, 2.7488918, 1.4839486], [1.5633398, 2.1358476, 1.3693532], [0.3598416, 1.8287641, 2.5314465]], dtype=float32)>)
# Demonstrate that the generated random numbers match
np.testing.assert_allclose(graph_a, a.numpy(), **tol_dict)
np.testing.assert_allclose(graph_b, b.numpy(), **tol_dict)
np.testing.assert_allclose(graph_c, c.numpy(), **tol_dict)
ただし、constant
モードでは、b
と c
は同じシードで同じ形状で生成されているため、正確に同じ値になります。
np.testing.assert_allclose(b.numpy(), c.numpy(), **tol_dict)
順位トレース
constant
モードで一部の乱数が一致して数値的等価性テストの信頼性が低下することが懸念される場合(たとえば、複数の重みが同じ初期化を行う場合)、これを回避するには num_random_ops
モードを使用します。num_random_ops
モードでは、生成される乱数は、プログラム内のランダム演算の順位に依存します。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
a = tf.random.uniform(shape=(3,1))
a = a * 3
b = tf.random.uniform(shape=(3,3))
b = b * 3
c = tf.random.uniform(shape=(3,3))
c = c * 3
graph_a, graph_b, graph_c = sess.run([a, b, c])
graph_a, graph_b, graph_c
(array([[2.5063772], [2.7488918], [1.4839486]], dtype=float32), array([[0.45038545, 1.9197761 , 2.4536333 ], [1.0371652 , 2.9898582 , 1.924583 ], [0.25679827, 1.6579313 , 2.8418403 ]], dtype=float32), array([[2.9634383 , 1.0862181 , 2.6042497 ], [0.70099247, 2.3920312 , 1.0470468 ], [0.18173039, 0.8359269 , 1.0508587 ]], dtype=float32))
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
a = tf.random.uniform(shape=(3,1))
a = a * 3
b = tf.random.uniform(shape=(3,3))
b = b * 3
c = tf.random.uniform(shape=(3,3))
c = c * 3
a, b, c
(<tf.Tensor: shape=(3, 1), dtype=float32, numpy= array([[2.5063772], [2.7488918], [1.4839486]], dtype=float32)>, <tf.Tensor: shape=(3, 3), dtype=float32, numpy= array([[0.45038545, 1.9197761 , 2.4536333 ], [1.0371652 , 2.9898582 , 1.924583 ], [0.25679827, 1.6579313 , 2.8418403 ]], dtype=float32)>, <tf.Tensor: shape=(3, 3), dtype=float32, numpy= array([[2.9634383 , 1.0862181 , 2.6042497 ], [0.70099247, 2.3920312 , 1.0470468 ], [0.18173039, 0.8359269 , 1.0508587 ]], dtype=float32)>)
# Demonstrate that the generated random numbers match
np.testing.assert_allclose(graph_a, a.numpy(), **tol_dict)
np.testing.assert_allclose(graph_b, b.numpy(), **tol_dict )
np.testing.assert_allclose(graph_c, c.numpy(), **tol_dict)
# Demonstrate that with the 'num_random_ops' mode,
# b & c took on different values even though
# their generated shape was the same
assert not np.allclose(b.numpy(), c.numpy(), **tol_dict)
ただし、このモードでの乱数生成はプログラムの順位に影響されるため、次の生成された乱数は一致しないことに注意してください。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
a = tf.random.uniform(shape=(3,1))
a = a * 3
b = tf.random.uniform(shape=(3,3))
b = b * 3
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
b_prime = tf.random.uniform(shape=(3,3))
b_prime = b_prime * 3
a_prime = tf.random.uniform(shape=(3,1))
a_prime = a_prime * 3
assert not np.allclose(a.numpy(), a_prime.numpy())
assert not np.allclose(b.numpy(), b_prime.numpy())
順位トレースによるデバッグのバリエーションを可能にするために、num_random_ops
モードで DeterministicRandomTestTool
を使用すると、operation_seed
プロパティでトレースされたランダム演算の数を確認できます。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
print(random_tool.operation_seed)
a = tf.random.uniform(shape=(3,1))
a = a * 3
print(random_tool.operation_seed)
b = tf.random.uniform(shape=(3,3))
b = b * 3
print(random_tool.operation_seed)
0 1 2
テストでさまざまな順位トレースを考慮する必要がある場合は、自動インクリメント operation_seed
を明示的に設定することもできます。たとえば、これを使用して、2 つの異なるプログラムの順位間で乱数生成を一致させることができます。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
print(random_tool.operation_seed)
a = tf.random.uniform(shape=(3,1))
a = a * 3
print(random_tool.operation_seed)
b = tf.random.uniform(shape=(3,3))
b = b * 3
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
random_tool.operation_seed = 1
b_prime = tf.random.uniform(shape=(3,3))
b_prime = b_prime * 3
random_tool.operation_seed = 0
a_prime = tf.random.uniform(shape=(3,1))
a_prime = a_prime * 3
np.testing.assert_allclose(a.numpy(), a_prime.numpy(), **tol_dict)
np.testing.assert_allclose(b.numpy(), b_prime.numpy(), **tol_dict)
0 1
ただし、DeterministicRandomTestTool
では、すでに使用されている演算シードの再利用が許可されていないため、自動インクリメントシーケンスが重複しないようにしてください。これは、Eager execution では同じ演算シードの後続の使用に対して異なる数が生成されるためです。TF1 Graphs と Session では異なる数は生成されません。そのため、エラーを発生させると、Session と Eager ステートフル乱数の生成を一致させることができます。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
random_tool.operation_seed = 1
b_prime = tf.random.uniform(shape=(3,3))
b_prime = b_prime * 3
random_tool.operation_seed = 0
a_prime = tf.random.uniform(shape=(3,1))
a_prime = a_prime * 3
try:
c = tf.random.uniform(shape=(3,1))
raise RuntimeError("An exception should have been raised before this, " +
"because the auto-incremented operation seed will " +
"overlap an already-used value")
except ValueError as err:
print(err)
This `DeterministicRandomTestTool` object is trying to re-use the already-used operation seed 1. It cannot guarantee random numbers will match between eager and sessions when an operation seed is reused. You most likely set `operation_seed` explicitly but used a value that caused the naturally-incrementing operation seed sequences to overlap with an already-used seed.
推論の検証
DeterministicRandomTestTool
を使用して、ランダムな重みの初期化を使用している場合でも、InceptionResnetV2
モデルが推論で一致することを確認できるようになりました。プログラムの順位が一致するより強力なテスト条件を得るには、num_random_ops
モードを使用します。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)
# Initialize the variables
sess.run(tf.compat.v1.global_variables_initializer())
# Grab the outputs & regularization loss
reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
tf1_output = sess.run(out)
print("Regularization loss:", tf1_regularization_loss)
Regularization loss: 1.2254326
height, width = 299, 299
num_classes = 1000
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
model = InceptionResnetV2(num_classes)
inputs = tf.ones((1, height, width, 3))
tf2_output, endpoints = model(inputs, training=False)
# Grab the regularization loss as well
tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
Regularization loss: tf.Tensor(1.2254325, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool:
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
トレーニングの検証
DeterministicRandomTestTool
は、すべてのステートフルなランダム演算(重みの初期化とドロップアウトレイヤーなどの計算を含む)で機能するため、これを使用して、トレーニングモードでもモデルが一致することを確認できます。ステートフルなランダム演算のプログラムの順位が一致するため、num_random_ops
モードを再度使用できます。
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=True)
# Initialize the variables
sess.run(tf.compat.v1.global_variables_initializer())
# Grab the outputs & regularization loss
reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
tf1_output = sess.run(out)
print("Regularization loss:", tf1_regularization_loss)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/layers/normalization/batch_normalization.py:581: _colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version. Instructions for updating: Colocations handled automatically by placer. Regularization loss: 1.22548
height, width = 299, 299
num_classes = 1000
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
model = InceptionResnetV2(num_classes)
inputs = tf.ones((1, height, width, 3))
tf2_output, endpoints = model(inputs, training=True)
# Grab the regularization loss as well
tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
Regularization loss: tf.Tensor(1.2254798, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
これで、tf.keras.layers.Layer
の周りのデコレータで Eager execution されている InceptionResnetV2
モデルが、TF1 Graph と Session で実行されているスリムネットワークと数値的に一致することを確認できました。
注意: num_random_ops
モードで DeterministicRandomTestTool
を使用する場合、数値的等価性のためにテスト時に tf.keras.layers.Layer
メソッドデコレータを直接使用して呼び出すことを推薦します。Keras functional モデルまたは他の Keras モデル内に埋め込むと、TF1.x Graph/Session と Eager execution を比較するときに、ステートフルなランダム演算の順位トレースに違いが生じ、正確に一致させるのが難しくなる可能性があります。
たとえば、InceptionResnetV2
レイヤーを training=True
で直接呼び出すと、変数の初期化がネットワークの作成順位に従ってドロップアウト順位でインターリーブされます。
一方、最初に tf.keras.layers.Layer
デコレータを Keras functional モデルに配置してから、そのモデルを training=True
で呼び出すことは、すべての変数を初期化し、ドロップアウトレイヤーを使用することと同じです。これにより、異なる順位トレースと異なる乱数セットが生成されます。
ただし、デフォルトの mode='constant'
は、これらの順位トレースの違いに影響されず、レイヤーを Keras functional モデルに埋め込む場合でも、追加の作業なしで渡せます。
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=True)
# Initialize the variables
sess.run(tf.compat.v1.global_variables_initializer())
# Get the outputs & regularization losses
reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
tf1_output = sess.run(out)
print("Regularization loss:", tf1_regularization_loss)
Regularization loss: 1.2239965
height, width = 299, 299
num_classes = 1000
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
keras_input = tf.keras.Input(shape=(height, width, 3))
layer = InceptionResnetV2(num_classes)
model = tf.keras.Model(inputs=keras_input, outputs=layer(keras_input))
inputs = tf.ones((1, height, width, 3))
tf2_output, endpoints = model(inputs, training=True)
# Get the regularization loss
tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/keras/engine/base_layer.py:1345: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically. warnings.warn('`layer.updates` will be removed in a future version. ' /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/legacy_tf_layers/base.py:627: UserWarning: `layer.updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically. self.updates, tf.compat.v1.GraphKeys.UPDATE_OPS Regularization loss: tf.Tensor(1.2239964, shape=(), dtype=float32)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
ステップ 3b 、4b(オプション): 既存のチェックポイントを使用したテスト
上記のステップ 3 またはステップ 4 の後、既存の名前ベースのチェックポイントがある場合は、そこから開始するときに数値的等価性テストを実行すると便利です。これにより、レガシーチェックポイントの読み込みが正しく機能していることと、モデル自体が正しく機能していることの両方をテストできます。TF1.x チェックポイントの再利用ガイドでは、既存の TF1.x チェックポイントを再利用して TF2 チェックポイントに移行する方法について説明されています。
追加のテストとトラブルシューティング
数値的等価性テストをさらに追加する場合、勾配計算(またはオプティマイザーの更新)の一致を検証するテストを追加することもできます。
バックプロパゲーションと勾配の計算は、モデルのフォワードパスよりも浮動小数点の数値が不安定になる傾向があります。これは、トレーニングの分離されていない部分の等価性をテストすると、完全に Eager execution を実行した場合と TF1 Graph との間に大きな数値上の違いが見られる可能性があることを意味します。これは、Graph 内の部分式をより少ない数学的演算に置き換えたりする TensorFlow Graph の最適化が原因である可能性があります。
これが当てはまる可能性があるかどうかを特定するには、TF1 コードを、純粋な Eager 計算ではなく、tf.function
(TF1 Graph のようなグラフ最適化パスを適用する)内で行われている TF2 計算と比較できます。または、TF1 計算の前に tf.config.optimizer.set_experimental_options
を使用して "arithmetic_optimization"
などの最適化パスを無効にして、結果が TF2 計算結果と数値的に近い値になるかどうかを確認することもできます。実際のトレーニングの実行では、パフォーマンス上の理由から最適化パスを有効にして tf.function
を使用することを推薦しますが、数値等価性の単体テストではそれらを無効にすることが役立つ場合があります。
同様に、tf.compat.v1.train
オプティマイザーと TF2 オプティマイザーは、それらが表す数式が同じであっても、TF2 オプティマイザーにはわずかに異なる浮動小数点数値プロパティがあります。これがトレーニングの実行で問題になる可能性は低いですが、等価性単体テストではより高い数値許容誤差が必要になる場合があります。