When migrating your TensorFlow code from TF1.x to TF2, it is a good practice to ensure that your migrated code behaves the same way in TF2 as it did in TF1.x.
This guide covers migration code examples with the tf.compat.v1.keras.utils.track_tf1_style_variables modeling shim applied to tf.keras.layers.Layer methods. Read the model mapping guide to find out more about the TF2 modeling shims.
This guide details approaches you can use to:
- Validate the correctness of the results obtained from training models using the migrated code
- Validate the numerical equivalence of your code across TensorFlow versions
Setup
pip uninstall -y -q tensorflow
# Install tf-nightly as the DeterministicRandomTestTool is available only in
# Tensorflow 2.8
pip install -q tf-nightly
pip install -q tf_slim
import tensorflow as tf
import tensorflow.compat.v1 as v1
import numpy as np
import tf_slim as slim
import sys
from contextlib import contextmanager
!git clone --depth=1 https://github.com/tensorflow/models.git
import models.research.slim.nets.inception_resnet_v2 as inception
If you're putting a nontrivial chunk of forward pass code into the shim, you want to know that it is behaving the same way as it did in TF1.x. For example, consider trying to put an entire TF-Slim Inception-Resnet-v2 model into the shim as such:
# TF1 Inception resnet v2 forward pass based on slim layers
def inception_resnet_v2(inputs, num_classes, is_training):
  with slim.arg_scope(
      inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):
    return inception.inception_resnet_v2(inputs, num_classes, is_training=is_training)
class InceptionResnetV2(tf.keras.layers.Layer):
  """Slim InceptionResnetV2 forward pass as a Keras layer"""

  def __init__(self, num_classes, **kwargs):
    super().__init__(**kwargs)
    self.num_classes = num_classes

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs, training=None):
    is_training = training or False
    # Slim does not accept `None` as a value for is_training,
    # Keras will still pass `None` to layers to construct functional models
    # without forcing the layer to always be in training or in inference.
    # However, `None` is generally considered to run layers in inference.
    with slim.arg_scope(
        inception.inception_resnet_v2_arg_scope(batch_norm_scale=True)):
      return inception.inception_resnet_v2(
          inputs, self.num_classes, is_training=is_training)
As it so happens, this layer actually works perfectly fine out of the box (complete with accurate regularization loss tracking).
However, this is not something you want to take for granted. Follow the below steps to verify that it is actually behaving as it did in TF1.x, down to observing perfect numerical equivalence. These steps can also help you triangulate what part of the forward pass is causing a divergence from TF1.x (identify if the divergence arises in the model forward pass as opposed to a different part of the model).
Step 1: Verify variables are only created once
The very first thing you should verify is that you have correctly built the model in a way that reuses variables in each call rather than accidentally creating and using new variables each time. For example, if your model creates a new Keras layer or calls tf.Variable in each forward pass call then it is most likely failing to capture variables and creating new ones each time.
Below are two context manager scopes you can use to detect when your model is creating new variables and debug which part of the model is doing it.
@contextmanager
def assert_no_variable_creations():
  """Assert no variables are created in this context manager scope."""
  def invalid_variable_creator(next_creator, **kwargs):
    raise ValueError("Attempted to create a new variable instead of reusing an existing one. Args: {}".format(kwargs))

  with tf.variable_creator_scope(invalid_variable_creator):
    yield

@contextmanager
def catch_and_raise_created_variables():
  """Raise all variables created within this context manager scope (if any)."""
  created_vars = []
  def variable_catcher(next_creator, **kwargs):
    var = next_creator(**kwargs)
    created_vars.append(var)
    return var

  with tf.variable_creator_scope(variable_catcher):
    yield
  if created_vars:
    raise ValueError("Created vars:", created_vars)
The first scope (assert_no_variable_creations()) will raise an error immediately once you try creating a variable within the scope. This allows you to inspect the stacktrace (and use interactive debugging) to figure out exactly what lines of code created a variable instead of reusing an existing one.
The second scope (catch_and_raise_created_variables()) will raise an exception at the end of the scope if any variables ended up being created. This exception will include the list of all variables created in the scope. This is useful for seeing the full set of weights your model is creating, in case you can spot general patterns. However, it is less useful for identifying the exact lines of code where those variables got created.
Use both scopes below to verify that the shim-based InceptionResnetV2 layer does not create any new variables after the first call (presumably reusing them).
model = InceptionResnetV2(1000)
height, width = 299, 299
num_classes = 1000
inputs = tf.ones( (1, height, width, 3))
# Create all weights on the first call
model(inputs)
# Verify that no new weights are created in followup calls
with assert_no_variable_creations():
  model(inputs)
with catch_and_raise_created_variables():
  model(inputs)
In the example below, observe how these context managers work on a layer that incorrectly creates new weights each time instead of reusing existing ones.
class BrokenScalingLayer(tf.keras.layers.Layer):
  """Scaling layer that incorrectly creates new weights each time:"""

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    var = tf.Variable(initial_value=2.0)
    bias = tf.Variable(initial_value=2.0, name='bias')
    return inputs * var + bias

model = BrokenScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)
try:
  with assert_no_variable_creations():
    model(inputs)
except ValueError as err:
  import traceback
  traceback.print_exc()

model = BrokenScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)
try:
  with catch_and_raise_created_variables():
    model(inputs)
except ValueError as err:
  print(err)
You can fix the layer by making sure it only creates the weights once and then reuses them each time.
class FixedScalingLayer(tf.keras.layers.Layer):
  """Scaling layer that creates its weights once and reuses them on later calls."""

  def __init__(self):
    super().__init__()
    self.var = None
    self.bias = None

  @tf.compat.v1.keras.utils.track_tf1_style_variables
  def call(self, inputs):
    # Create the weights only on the first call; reuse them afterwards
    if self.var is None:
      self.var = tf.Variable(initial_value=2.0)
      self.bias = tf.Variable(initial_value=2.0, name='bias')
    return inputs * self.var + self.bias

model = FixedScalingLayer()
inputs = tf.ones( (1, height, width, 3))
model(inputs)
with assert_no_variable_creations():
  model(inputs)
with catch_and_raise_created_variables():
  model(inputs)
Troubleshooting
Here are some common reasons why your model might accidentally be creating new weights instead of reusing existing ones:
- It uses an explicit tf.Variable call without reusing already-created tf.Variables. Fix this by first checking whether the variable has already been created, and reusing the existing one if so.
- It creates a Keras layer or model directly in the forward pass each time (as opposed to tf.compat.v1.layers). Fix this by first checking whether the layer or model has already been created, and reusing the existing one if so.
- It is built on top of tf.compat.v1.layers but fails to assign all compat.v1.layers an explicit name or to wrap the compat.v1.layers usage inside of a named variable_scope, causing the autogenerated layer names to increment in each model call. Fix this by putting a named tf.compat.v1.variable_scope inside your shim-decorated method that wraps all of your tf.compat.v1.layers usage.
Step 2: Check that variable counts, names, and shapes match
The second step is to make sure your layer running in TF2 creates the same number of weights, with the same shapes, as the corresponding code does in TF1.x.
You can do a mix of manually checking them to see that they match, and doing the checks programmatically in a unit test as shown below.
# Build the forward pass inside a TF1.x graph, and
# get the counts, shapes, and names of the variables
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
  height, width = 299, 299
  num_classes = 1000
  inputs = tf.ones( (1, height, width, 3))
  out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)
  tf1_variable_names_and_shapes = {
      var.name: (var.trainable, var.shape) for var in tf.compat.v1.global_variables()}
  num_tf1_variables = len(tf.compat.v1.global_variables())
Next, do the same for the shim-wrapped layer in TF2. Notice that the model is also called multiple times before grabbing the weights. This is done to effectively test for variable reuse.
height, width = 299, 299
num_classes = 1000
model = InceptionResnetV2(num_classes)
# The weights will not be created until you call the model
inputs = tf.ones( (1, height, width, 3))
# Call the model multiple times before checking the weights, to verify variables
# get reused rather than accidentally creating additional variables
out, endpoints = model(inputs, training=False)
out, endpoints = model(inputs, training=False)
# Grab the name: shape mapping and the total number of variables separately,
# because in TF2 variables can be created with the same name
num_tf2_variables = len(model.variables)
tf2_variable_names_and_shapes = {
    var.name: (var.trainable, var.shape) for var in model.variables}
# Verify that the variable counts, names, and shapes all match:
assert num_tf1_variables == num_tf2_variables
assert tf1_variable_names_and_shapes == tf2_variable_names_and_shapes
The shim-based InceptionResnetV2 layer passes this test. However, in the case where they don't match, you can run it through a diff (text or other) to see where the differences are.
This can provide a clue as to what part of the model isn't behaving as expected. With eager execution you can use pdb, interactive debugging, and breakpoints to dig into the parts of the model that seem suspicious, and debug what is going wrong in more depth.
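When the two mappings do not match, a small set-based comparison is often easier to read than a raw text diff. The sketch below is one possible helper, not part of TensorFlow: the function name `diff_variable_specs` and the sample variable names and shapes are hypothetical, and shapes are written as plain tuples for illustration. It should also work on the `{name: (trainable, shape)}` dictionaries built above, assuming the shape objects support equality comparison.

```python
def diff_variable_specs(tf1_specs, tf2_specs):
    """Compare two {name: (trainable, shape)} dicts and report the mismatches."""
    tf1_names, tf2_names = set(tf1_specs), set(tf2_specs)
    return {
        # Variables that exist on only one side (often a naming difference)
        "only_in_tf1": sorted(tf1_names - tf2_names),
        "only_in_tf2": sorted(tf2_names - tf1_names),
        # Variables present on both sides whose trainable flag or shape differs
        "mismatched": {
            name: (tf1_specs[name], tf2_specs[name])
            for name in sorted(tf1_names & tf2_names)
            if tf1_specs[name] != tf2_specs[name]
        },
    }

# Hypothetical specs, made up for illustration only.
tf1_specs = {
    "Conv2d_1a/weights:0": (True, (3, 3, 3, 32)),
    "Conv2d_1a/BatchNorm/beta:0": (True, (32,)),
}
tf2_specs = {
    "Conv2d_1a/weights:0": (True, (3, 3, 3, 64)),
    "Conv2d_1a_1/BatchNorm/beta:0": (True, (32,)),
}

report = diff_variable_specs(tf1_specs, tf2_specs)
for key, value in report.items():
    print(key, value)
```

A name that appears with an incremented suffix on one side only (such as the `_1` in the made-up example) usually points at a variable-reuse or scoping problem rather than a shape problem.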
Troubleshooting
- Pay close attention to the names of any variables created directly by explicit tf.Variable calls and Keras layers/models, as their variable name generation semantics may differ slightly between TF1.x graphs and TF2 functionality such as eager execution and tf.function even if everything else is working properly. If this is the case for you, adjust your test to account for any slightly different naming semantics.
- You may sometimes find that the tf.Variables, tf.keras.layers.Layers, or tf.keras.Models created in your training loop's forward pass are missing from your TF2 variables list even if they were captured by the variables collection in TF1.x. Fix this by assigning the variables/layers/models that your forward pass creates to instance attributes in your model. See here for more info.
Step 3: Reset all variables, check numerical equivalence with all randomness disabled
The next step is to verify numerical equivalence for both the actual outputs and the regularization loss tracking when you fix the model such that there is no random number generation involved (such as during inference).
The exact way to do this may depend on your specific model, but in most models (such as this one), you can do this by:
- Initializing the weights to the same value with no randomness. This can be done by resetting them to a fixed value after they have been created.
- Running the model in inference mode to avoid triggering any dropout layers which can be sources of randomness.
The following code demonstrates how you can compare the TF1.x and TF2 results this way.
graph = tf.Graph()
with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
  height, width = 299, 299
  num_classes = 1000
  inputs = tf.ones( (1, height, width, 3))
  out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)
  # Rather than running the global variable initializers,
  # reset all variables to a constant value
  var_reset = tf.group([var.assign(tf.ones_like(var) * 0.001) for var in tf.compat.v1.global_variables()])
  sess.run(var_reset)
  # Grab the outputs & regularization loss
  reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
  tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
  tf1_output = sess.run(out)
print("Regularization loss:", tf1_regularization_loss)
tf1_output[0][:5]
Get the TF2 results.
height, width = 299, 299
num_classes = 1000
model = InceptionResnetV2(num_classes)
inputs = tf.ones((1, height, width, 3))
# Call the model once to create the weights
out, endpoints = model(inputs, training=False)
# Reset all variables to the same fixed value as above, with no randomness
for var in model.variables:
  var.assign(tf.ones_like(var) * 0.001)
tf2_output, endpoints = model(inputs, training=False)
# Get the regularization loss
tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
tf2_output[0][:5]
# Create a dict of tolerance values
tol_dict={'rtol':1e-06, 'atol':1e-05}
# Verify that the regularization loss and output both match
# when we fix the weights and avoid randomness by running inference:
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
The numbers match between TF1.x and TF2 when you remove sources of randomness, and the TF2-compatible InceptionResnetV2 layer passes the test.
If you are observing the results diverging for your own models, you can use printing or pdb and interactive debugging to identify where and why the results start to diverge. Eager execution can make this significantly easier. You can also use an ablation approach to run only small portions of the model on fixed intermediate inputs and isolate where the divergence happens.
Conveniently, many slim nets (and other models) also expose intermediate endpoints that you can probe.
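One way to use such endpoints is to walk through them in order and stop at the first one whose values disagree. The sketch below assumes both endpoint dictionaries have already been converted to NumPy arrays, and that the iteration order of the first dictionary follows the order in which the forward pass produced the endpoints. The helper name `first_divergent_endpoint` and the sample data are hypothetical.

```python
import numpy as np

def first_divergent_endpoint(tf1_endpoints, tf2_endpoints, rtol=1e-6, atol=1e-5):
    """Return the name of the first endpoint whose values differ, or None.

    Both arguments map endpoint names to array-like values of matching shape.
    """
    for name, tf1_value in tf1_endpoints.items():
        tf2_value = np.asarray(tf2_endpoints[name])
        if not np.allclose(tf1_value, tf2_value, rtol=rtol, atol=atol):
            return name
    return None

# Made-up endpoints: "block_a" agrees, "block_b" is where the results split.
tf1_eps = {"block_a": np.ones((2, 2)), "block_b": np.full((2,), 2.0), "logits": np.zeros(3)}
tf2_eps = {"block_a": np.ones((2, 2)), "block_b": np.full((2,), 2.5), "logits": np.ones(3)}

print(first_divergent_endpoint(tf1_eps, tf2_eps))  # block_b
print(first_divergent_endpoint(tf1_eps, tf1_eps))  # None
```

Everything downstream of the first divergent endpoint will usually differ as well, so the earliest mismatch is the one worth debugging.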
Step 4: Align random number generation, check numerical equivalence in both training and inference
The final step is to verify that the TF2 model numerically matches the TF1.x model, even when accounting for random number generation in variable initialization and in the forward pass itself (such as dropout layers during the forward pass).
You can do this by using the testing tool below to make random number generation semantics match between TF1.x graphs/sessions and eager execution.
TF1 legacy graphs/sessions and TF2 eager execution use different stateful random number generation semantics.
In tf.compat.v1.Sessions, if no seeds are specified, the random number generation depends on how many operations are in the graph at the time when the random operation is added, and how many times the graph is run. In eager execution, stateful random number generation depends on the global seed, the operation random seed, and how many times the operation with the given random seed is run. See tf.random.set_seed for more info.
The following v1.keras.utils.DeterministicRandomTestTool class provides a context manager scope() that can make stateful random operations use the same seed across both TF1 graphs/sessions and eager execution.
The tool provides two testing modes:
- constant, which uses the same seed for every single operation no matter how many times it has been called, and
- num_random_ops, which uses the number of previously-observed stateful random operations as the operation seed.
This applies both to the stateful random operations used for creating and initializing variables, and to the stateful random operations used in computation (such as for dropout layers).
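To build intuition for the difference between the two modes, here is a toy model written with Python's standard `random` module. It is not how DeterministicRandomTestTool is implemented; the class `ToySeedTool`, its `base_seed` parameter, and the way the two seeds are combined are all assumptions made for illustration.

```python
import random

class ToySeedTool:
    """Toy stand-in for the two seeding modes; NOT the real implementation."""

    def __init__(self, mode="constant", base_seed=42):
        if mode not in ("constant", "num_random_ops"):
            raise ValueError(f"Unknown mode: {mode}")
        self.mode = mode
        self.base_seed = base_seed
        self.ops_seen = 0  # number of random ops observed so far

    def uniform(self, n):
        """Return n pseudo-random floats from a freshly seeded generator."""
        # constant: every op gets the same seed.
        # num_random_ops: the op seed is the count of ops seen before this one.
        op_seed = 0 if self.mode == "constant" else self.ops_seen
        self.ops_seen += 1
        rng = random.Random(self.base_seed * 1000 + op_seed)
        return [rng.random() for _ in range(n)]

constant = ToySeedTool(mode="constant")
b, c = constant.uniform(3), constant.uniform(3)
print(b == c)  # True: both ops used the same seed

counted = ToySeedTool(mode="num_random_ops")
b2, c2 = counted.uniform(3), counted.uniform(3)
print(b2 == c2)  # False: the second op used a different seed

# A fresh tool replays the same sequence, which is what makes two runs comparable.
replay = ToySeedTool(mode="num_random_ops").uniform(3)
print(replay == b2)  # True
```

The toy also shows why the counting mode is sensitive to program order: swapping the order of two ops swaps the seeds they receive.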
Generate three random tensors to show how to use this tool to make stateful random number generation match between sessions and eager execution.
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
  graph = tf.Graph()
  with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
    a = tf.random.uniform(shape=(3,1))
    a = a * 3
    b = tf.random.uniform(shape=(3,3))
    b = b * 3
    c = tf.random.uniform(shape=(3,3))
    c = c * 3
    graph_a, graph_b, graph_c = sess.run([a, b, c])
graph_a, graph_b, graph_c
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
  a = tf.random.uniform(shape=(3,1))
  a = a * 3
  b = tf.random.uniform(shape=(3,3))
  b = b * 3
  c = tf.random.uniform(shape=(3,3))
  c = c * 3
a, b, c
# Demonstrate that the generated random numbers match
np.testing.assert_allclose(graph_a, a.numpy(), **tol_dict)
np.testing.assert_allclose(graph_b, b.numpy(), **tol_dict)
np.testing.assert_allclose(graph_c, c.numpy(), **tol_dict)
However, notice that in constant mode, because b and c were generated with the same seed and have the same shape, they will have exactly the same values.
np.testing.assert_allclose(b.numpy(), c.numpy(), **tol_dict)
Trace order
If you are worried about some random numbers matching in constant mode reducing your confidence in your numerical equivalence test (for example if several weights take on the same initializations), you can use the num_random_ops mode to avoid this. In the num_random_ops mode, the generated random numbers will depend on the ordering of random ops in the program.
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  graph = tf.Graph()
  with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
    a = tf.random.uniform(shape=(3,1))
    a = a * 3
    b = tf.random.uniform(shape=(3,3))
    b = b * 3
    c = tf.random.uniform(shape=(3,3))
    c = c * 3
    graph_a, graph_b, graph_c = sess.run([a, b, c])
graph_a, graph_b, graph_c
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  a = tf.random.uniform(shape=(3,1))
  a = a * 3
  b = tf.random.uniform(shape=(3,3))
  b = b * 3
  c = tf.random.uniform(shape=(3,3))
  c = c * 3
a, b, c
# Demonstrate that the generated random numbers match
np.testing.assert_allclose(graph_a, a.numpy(), **tol_dict)
np.testing.assert_allclose(graph_b, b.numpy(), **tol_dict)
np.testing.assert_allclose(graph_c, c.numpy(), **tol_dict)
# Demonstrate that with the 'num_random_ops' mode,
# b & c took on different values even though
# their generated shape was the same
assert not np.allclose(b.numpy(), c.numpy(), **tol_dict)
However, notice that in this mode random generation is sensitive to program order, and so the following generated random numbers do not match.
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  a = tf.random.uniform(shape=(3,1))
  a = a * 3
  b = tf.random.uniform(shape=(3,3))
  b = b * 3
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  b_prime = tf.random.uniform(shape=(3,3))
  b_prime = b_prime * 3
  a_prime = tf.random.uniform(shape=(3,1))
  a_prime = a_prime * 3
assert not np.allclose(a.numpy(), a_prime.numpy())
assert not np.allclose(b.numpy(), b_prime.numpy())
To allow for debugging variations due to tracing order, DeterministicRandomTestTool in num_random_ops mode allows you to see how many random operations have been traced with the operation_seed property.
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  print(random_tool.operation_seed)
  a = tf.random.uniform(shape=(3,1))
  a = a * 3
  print(random_tool.operation_seed)
  b = tf.random.uniform(shape=(3,3))
  b = b * 3
  print(random_tool.operation_seed)
If you need to account for varying trace order in your tests, you can even set the auto-incrementing operation_seed explicitly. For example, you can use this to make random number generation match across two different program orders.
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  print(random_tool.operation_seed)
  a = tf.random.uniform(shape=(3,1))
  a = a * 3
  print(random_tool.operation_seed)
  b = tf.random.uniform(shape=(3,3))
  b = b * 3
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  random_tool.operation_seed = 1
  b_prime = tf.random.uniform(shape=(3,3))
  b_prime = b_prime * 3
  random_tool.operation_seed = 0
  a_prime = tf.random.uniform(shape=(3,1))
  a_prime = a_prime * 3
np.testing.assert_allclose(a.numpy(), a_prime.numpy(), **tol_dict)
np.testing.assert_allclose(b.numpy(), b_prime.numpy(), **tol_dict)
However, DeterministicRandomTestTool disallows reusing already-used operation seeds, so make sure the auto-incremented sequences cannot overlap. This is because eager execution generates different numbers for follow-on usages of the same operation seed while TF1 graphs and sessions do not, so raising an error helps keep session and eager stateful random number generation in line.
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  random_tool.operation_seed = 1
  b_prime = tf.random.uniform(shape=(3,3))
  b_prime = b_prime * 3
  random_tool.operation_seed = 0
  a_prime = tf.random.uniform(shape=(3,1))
  a_prime = a_prime * 3
  try:
    c = tf.random.uniform(shape=(3,1))
    raise RuntimeError("An exception should have been raised before this, " +
                       "because the auto-incremented operation seed will " +
                       "overlap an already-used value")
  except ValueError as err:
    print(err)
Verifying Inference
You can now use the DeterministicRandomTestTool to make sure the InceptionResnetV2 model matches in inference, even when using the random weight initialization. For a stronger test condition due to matching program order, use the num_random_ops mode.
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  graph = tf.Graph()
  with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
    height, width = 299, 299
    num_classes = 1000
    inputs = tf.ones( (1, height, width, 3))
    out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=False)
    # Initialize the variables
    sess.run(tf.compat.v1.global_variables_initializer())
    # Grab the outputs & regularization loss
    reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
    tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
    tf1_output = sess.run(out)
  print("Regularization loss:", tf1_regularization_loss)
height, width = 299, 299
num_classes = 1000
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  model = InceptionResnetV2(num_classes)
  inputs = tf.ones((1, height, width, 3))
  tf2_output, endpoints = model(inputs, training=False)
  # Grab the regularization loss as well
  tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool:
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
Verifying Training
Because DeterministicRandomTestTool works for all stateful random operations (including both weight initialization and computation such as dropout layers), you can use it to verify the models match in training mode as well. You can again use the num_random_ops mode because the program order of the stateful random ops matches.
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  graph = tf.Graph()
  with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
    height, width = 299, 299
    num_classes = 1000
    inputs = tf.ones( (1, height, width, 3))
    out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=True)
    # Initialize the variables
    sess.run(tf.compat.v1.global_variables_initializer())
    # Grab the outputs & regularization loss
    reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
    tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
    tf1_output = sess.run(out)
  print("Regularization loss:", tf1_regularization_loss)
height, width = 299, 299
num_classes = 1000
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  model = InceptionResnetV2(num_classes)
  inputs = tf.ones((1, height, width, 3))
  tf2_output, endpoints = model(inputs, training=True)
  # Grab the regularization loss as well
  tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
You have now verified that the InceptionResnetV2 model running eagerly with decorators around tf.keras.layers.Layer numerically matches the slim network running in TF1 graphs and sessions.
Note that in num_random_ops mode the match depends on how the layer is called, because this mode is sensitive to the order in which stateful random ops are traced.
For example, calling the InceptionResnetV2 layer directly with training=True interleaves variable initialization with the dropout order according to the network creation order.
On the other hand, first putting the decorated tf.keras.layers.Layer in a Keras functional model and only then calling the model with training=True is equivalent to initializing all variables then using the dropout layer. This produces a different tracing order and a different set of random numbers.
However, the default mode='constant' is not sensitive to these differences in tracing order and will pass without extra work even when embedding the layer in a Keras functional model.
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
  graph = tf.Graph()
  with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
    height, width = 299, 299
    num_classes = 1000
    inputs = tf.ones( (1, height, width, 3))
    out, endpoints = inception_resnet_v2(inputs, num_classes, is_training=True)
    # Initialize the variables
    sess.run(tf.compat.v1.global_variables_initializer())
    # Get the outputs & regularization losses
    reg_losses = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES)
    tf1_regularization_loss = sess.run(tf.math.add_n(reg_losses))
    tf1_output = sess.run(out)
  print("Regularization loss:", tf1_regularization_loss)
height, width = 299, 299
num_classes = 1000
random_tool = v1.keras.utils.DeterministicRandomTestTool()
with random_tool.scope():
  keras_input = tf.keras.Input(shape=(height, width, 3))
  layer = InceptionResnetV2(num_classes)
  model = tf.keras.Model(inputs=keras_input, outputs=layer(keras_input))
  inputs = tf.ones((1, height, width, 3))
  tf2_output, endpoints = model(inputs, training=True)
  # Get the regularization loss
  tf2_regularization_loss = tf.math.add_n(model.losses)
print("Regularization loss:", tf2_regularization_loss)
# Verify that the regularization loss and output both match
# when using the DeterministicRandomTestTool
np.testing.assert_allclose(tf1_regularization_loss, tf2_regularization_loss.numpy(), **tol_dict)
np.testing.assert_allclose(tf1_output, tf2_output.numpy(), **tol_dict)
Step 3b or 4b (optional): Testing with pre-existing checkpoints
After step 3 or step 4 above, it can be useful to run your numerical equivalence tests when starting from pre-existing name-based checkpoints if you have some. This can test both that your legacy checkpoint loading is working correctly and that the model itself is working right. The Reusing TF1.x checkpoints guide covers how to reuse your pre-existing TF1.x checkpoints and transfer them over to TF2 checkpoints.
Additional Testing & Troubleshooting
As you add more numerical equivalence tests, you may also choose to add a test that verifies your gradient computation (or even your optimizer updates) match.
Backpropagation and gradient computation are more prone to floating point numerical instabilities than model forward passes. This means that as your equivalence tests cover more non-isolated parts of your training, you may begin to see non-trivial numerics differences between running fully eagerly and your TF1 graphs. This may be caused by TensorFlow's graph optimizations that do things such as replace subexpressions in a graph with fewer mathematical operations.
To isolate whether this is likely to be the case, you can compare your TF1 code to TF2 computation happening inside of a tf.function (which applies graph optimization passes like your TF1 graph) rather than to a purely eager computation. Alternatively, you can try using tf.config.optimizer.set_experimental_options to disable optimization passes such as "arithmetic_optimization" before your TF1 computation to see if the result ends up numerically closer to your TF2 computation results. In your actual training runs it is recommended you use tf.function with optimization passes enabled for performance reasons, but you may find it useful to disable them in your numerical equivalence unit tests.
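The kind of difference that rewriting a subexpression can introduce is easy to reproduce with plain Python floats. The sketch below does not use TensorFlow or any of its optimization passes; it only illustrates that floating point addition is not associative, so two mathematically equivalent groupings can give slightly different results.

```python
import math

# Two mathematically equivalent ways to sum the same three numbers.
left_to_right = (0.1 + 0.2) + 0.3
regrouped = 0.1 + (0.2 + 0.3)

print(left_to_right == regrouped)      # False
print(abs(left_to_right - regrouped))  # tiny, but nonzero

# The results still agree within a small relative tolerance.
print(math.isclose(left_to_right, regrouped, rel_tol=1e-9))  # True
```

Differences of this size are harmless on their own, but they can accumulate across many operations, which is why exact equality is rarely the right assertion for gradient or optimizer tests.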
Similarly, you may also find that tf.compat.v1.train optimizers have slightly different floating point numerics properties than TF2 optimizers, even if the mathematical formulas they are representing are the same. This is less likely to be an issue in your training runs, but it may require a higher numerical tolerance in equivalence unit tests.
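To make the effect of a higher tolerance concrete, the sketch below reimplements, for scalars only, the criterion that np.testing.assert_allclose applies element-wise: the absolute difference must not exceed atol + rtol * |desired|. The helper name `within_tolerance` and the two sample values are hypothetical.

```python
def within_tolerance(actual, desired, rtol, atol):
    """Scalar version of the check |actual - desired| <= atol + rtol * |desired|."""
    return abs(actual - desired) <= atol + rtol * abs(desired)

# Hypothetical outputs that differ by roughly 1e-4.
tf1_value, tf2_value = 1.0, 1.0001

# Same tolerances as the tol_dict used earlier in this guide.
strict = within_tolerance(tf2_value, tf1_value, rtol=1e-6, atol=1e-5)
# A looser relative tolerance, as an optimizer-update test might need.
relaxed = within_tolerance(tf2_value, tf1_value, rtol=1e-3, atol=1e-5)

print(strict, relaxed)  # False True
```

Because rtol scales with the magnitude of the expected value while atol does not, raising rtol is usually the gentler adjustment for large values, and atol matters most for values near zero.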