This notebook demonstrates how to debug a training pipeline when migrating to TensorFlow 2 (TF2). It consists of the following components:
- Suggested steps and code samples for debugging the training pipeline
- Tools for debugging
- Other related resources
One assumption is that you have TensorFlow 1 (TF1.x) code and trained models for comparison, and you want to build a TF2 model that achieves similar validation accuracy.
This notebook does NOT cover debugging performance issues for training/inference speed or memory usage.
Debugging workflow
Below is a general workflow for debugging your TF2 training pipelines. Note that you do not need to follow these steps in order. You can also use a binary search approach where you test the model in an intermediate step and narrow down the debugging scope.
1. Fix compile and runtime errors
2. Single forward pass validation (in a separate guide)
   a. On a single CPU device
      - Verify that variables are created only once
      - Check that variable counts, names, and shapes match
      - Reset all variables and check numerical equivalence with all randomness disabled
      - Align random number generation and check numerical equivalence in inference
      - (Optional) Check that checkpoints are loaded properly and that the TF1.x and TF2 models generate identical output
   b. On a single GPU/TPU device
   c. With multi-device strategies
3. Model training numerical equivalence validation for a few steps (code samples available below)
   a. Single training step validation using small, fixed data on a single CPU device. Specifically, check numerical equivalence for the following components:
      - loss computation
      - metrics
      - learning rate
      - gradient computation and update
   b. Check statistics after training three or more steps to verify optimizer behaviors such as momentum, still with fixed data on a single CPU device
   c. On a single GPU/TPU device
   d. With multi-device strategies (check the intro to MultiProcessRunner at the bottom)
4. End-to-end convergence testing on a real dataset
   a. Check training behaviors with TensorBoard
      - Use simple optimizers (e.g., SGD) and simple distribution strategies (e.g., tf.distribute.OneDeviceStrategy) first
      - Compare training metrics
      - Compare evaluation metrics
      - Figure out what a reasonable tolerance for inherent randomness is
   b. Check equivalence with advanced optimizers, learning rate schedulers, and distribution strategies
   c. Check equivalence when using mixed precision
5. Additional product benchmarks
Setup
# The `DeterministicRandomTestTool` is only available from TensorFlow 2.8:
pip install -q "tensorflow==2.9.*"
Single forward pass validation
Single forward pass validation, including checkpoint loading, is covered in a different colab.
import sys
import unittest
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as v1
Model training numerical equivalence validation for a few steps
Set up model configuration and prepare a fake dataset.
params = {
    'input_size': 3,
    'num_classes': 3,
    'layer_1_size': 2,
    'layer_2_size': 2,
    'num_train_steps': 100,
    'init_lr': 1e-3,
    'end_lr': 0.0,
    'decay_steps': 1000,
    'lr_power': 1.0,
}
# make a small fixed dataset
fake_x = np.ones((2, params['input_size']), dtype=np.float32)
fake_y = np.zeros((2, params['num_classes']), dtype=np.int32)
fake_y[0][0] = 1
fake_y[1][1] = 1
step_num = 3
Define the TF1.x model.
# Assume there is an existing TF1.x model using the Estimator API.
# Wrap the model_fn to log necessary tensors for result comparison.
class SimpleModelWrapper():
  def __init__(self):
    self.logged_ops = {}
    self.logs = {
        'step': [],
        'lr': [],
        'loss': [],
        'grads_and_vars': [],
        'layer_out': []}

  def model_fn(self, features, labels, mode, params):
    out_1 = tf.compat.v1.layers.dense(features, units=params['layer_1_size'])
    out_2 = tf.compat.v1.layers.dense(out_1, units=params['layer_2_size'])
    logits = tf.compat.v1.layers.dense(out_2, units=params['num_classes'])
    loss = tf.compat.v1.losses.softmax_cross_entropy(labels, logits)

    # skip EstimatorSpec details for prediction and evaluation
    if mode == tf.estimator.ModeKeys.PREDICT:
      pass
    if mode == tf.estimator.ModeKeys.EVAL:
      pass
    assert mode == tf.estimator.ModeKeys.TRAIN

    global_step = tf.compat.v1.train.get_or_create_global_step()
    lr = tf.compat.v1.train.polynomial_decay(
        learning_rate=params['init_lr'],
        global_step=global_step,
        decay_steps=params['decay_steps'],
        end_learning_rate=params['end_lr'],
        power=params['lr_power'])

    optimizer = tf.compat.v1.train.GradientDescentOptimizer(lr)
    # use the default graph so the wrapper works inside any graph context
    grads_and_vars = optimizer.compute_gradients(
        loss=loss,
        var_list=tf.compat.v1.get_default_graph().get_collection(
            tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES))
    train_op = optimizer.apply_gradients(
        grads_and_vars,
        global_step=global_step)

    # log tensors
    self.logged_ops['step'] = global_step
    self.logged_ops['lr'] = lr
    self.logged_ops['loss'] = loss
    self.logged_ops['grads_and_vars'] = grads_and_vars
    self.logged_ops['layer_out'] = {
        'layer_1': out_1,
        'layer_2': out_2,
        'logits': logits}

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

  def update_logs(self, logs):
    for key in logs.keys():
      self.logs[key].append(logs[key])
The following `v1.keras.utils.DeterministicRandomTestTool` class provides a context manager `scope()` that can make stateful random operations use the same seed across both TF1 graphs/sessions and eager execution.
The tool provides two testing modes:
- `constant`, which uses the same seed for every single operation no matter how many times it has been called, and
- `num_random_ops`, which uses the number of previously-observed stateful random operations as the operation seed.
This applies both to the stateful random operations used for creating and initializing variables, and to the stateful random operations used in computation (such as for dropout layers).
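For intuition, here is a minimal sketch (not from the original guide; the variable names are illustrative) contrasting the two modes:
# In 'constant' mode, every stateful random op reuses the same operation seed,
# so repeated draws inside the scope are expected to be identical.
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='constant')
with random_tool.scope():
  a = tf.random.uniform((2, 2))
  b = tf.random.uniform((2, 2))  # expected to equal `a`

# In 'num_random_ops' mode, the seed advances with each stateful random op,
# so `c` and `d` differ, but re-running this block reproduces the same values.
# Matching the *order* of random ops between TF1.x and TF2 is what matters here.
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  c = tf.random.uniform((2, 2))
  d = tf.random.uniform((2, 2))  # differs from `c`, but reproducible across reruns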
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
Run the TF1.x model in graph mode. Collect statistics for the first 3 training steps for numerical equivalence comparison.
with random_tool.scope():
  graph = tf.Graph()
  with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
    model_tf1 = SimpleModelWrapper()
    # build the model
    inputs = tf.compat.v1.placeholder(tf.float32, shape=(None, params['input_size']))
    labels = tf.compat.v1.placeholder(tf.float32, shape=(None, params['num_classes']))
    spec = model_tf1.model_fn(inputs, labels, tf.estimator.ModeKeys.TRAIN, params)
    train_op = spec.train_op

    sess.run(tf.compat.v1.global_variables_initializer())
    for step in range(step_num):
      # log everything and update the model for one step
      logs, _ = sess.run(
          [model_tf1.logged_ops, train_op],
          feed_dict={inputs: fake_x, labels: fake_y})
      model_tf1.update_logs(logs)
Define the TF2 model.
class SimpleModel(tf.keras.Model):
  def __init__(self, params, *args, **kwargs):
    super(SimpleModel, self).__init__(*args, **kwargs)
    # define the model
    self.dense_1 = tf.keras.layers.Dense(params['layer_1_size'])
    self.dense_2 = tf.keras.layers.Dense(params['layer_2_size'])
    self.out = tf.keras.layers.Dense(params['num_classes'])
    learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=params['init_lr'],
        decay_steps=params['decay_steps'],
        end_learning_rate=params['end_lr'],
        power=params['lr_power'])
    self.optimizer = tf.keras.optimizers.legacy.SGD(learning_rate_fn)
    self.compiled_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
    self.logs = {
        'lr': [],
        'loss': [],
        'grads': [],
        'weights': [],
        'layer_out': []}

  def call(self, inputs):
    out_1 = self.dense_1(inputs)
    out_2 = self.dense_2(out_1)
    logits = self.out(out_2)
    # log output features for every layer for comparison
    layer_wise_out = {
        'layer_1': out_1,
        'layer_2': out_2,
        'logits': logits}
    self.logs['layer_out'].append(layer_wise_out)
    return logits

  def train_step(self, data):
    x, y = data
    with tf.GradientTape() as tape:
      logits = self(x)
      loss = self.compiled_loss(y, logits)
    grads = tape.gradient(loss, self.trainable_weights)
    # log training statistics
    step = self.optimizer.iterations.numpy()
    self.logs['lr'].append(self.optimizer.learning_rate(step).numpy())
    self.logs['loss'].append(loss.numpy())
    self.logs['grads'].append(grads)
    self.logs['weights'].append(self.trainable_weights)
    # update model
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
    return
Run the TF2 model in eager mode. Collect statistics for the first 3 training steps for numerical equivalence comparison.
random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
with random_tool.scope():
  model_tf2 = SimpleModel(params)
  for step in range(step_num):
    model_tf2.train_step([fake_x, fake_y])
Compare numerical equivalence for the first few training steps.
You can also check the Validating correctness & numerical equivalence notebook for additional advice on numerical equivalence.
np.testing.assert_allclose(model_tf1.logs['lr'], model_tf2.logs['lr'])
np.testing.assert_allclose(model_tf1.logs['loss'], model_tf2.logs['loss'])
for step in range(step_num):
  for name in model_tf1.logs['layer_out'][step]:
    np.testing.assert_allclose(
        model_tf1.logs['layer_out'][step][name],
        model_tf2.logs['layer_out'][step][name])
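The same kind of check can be extended to the gradients, which both wrappers above already log. The sketch below assumes the variables are created in the same order in both models (kernel then bias, layer by layer), so the gradient lists line up element-wise; the tolerance may need tuning for your model.
# TF1.x logs (gradient, variable) pairs; the TF2 model logs gradients directly.
for step in range(step_num):
  tf1_grads = [grad for grad, _ in model_tf1.logs['grads_and_vars'][step]]
  tf2_grads = model_tf2.logs['grads'][step]
  for tf1_grad, tf2_grad in zip(tf1_grads, tf2_grads):
    np.testing.assert_allclose(tf1_grad, tf2_grad.numpy(), atol=1e-6)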
Unit tests
There are a few types of unit testing that can help debug your migration code.
- Single forward pass validation
- Model training numerical equivalence validation for a few steps
- Benchmark inference performance
- The trained model makes correct predictions on fixed and simple data points
You can use `@parameterized.parameters` to test models with different configurations (see the sketch below for one example).
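For instance, a minimal sketch using absl's `parameterized` that exercises the TF2 model defined earlier under different layer widths; the test class name and the sizes below are hypothetical.
from absl.testing import parameterized

class ModelConfigTest(parameterized.TestCase):

  @parameterized.parameters(
      {'layer_1_size': 2, 'layer_2_size': 2},
      {'layer_1_size': 4, 'layer_2_size': 8},
  )
  def test_output_shape(self, layer_1_size, layer_2_size):
    # Build the model with the overridden layer sizes and check that the
    # logits have the expected shape for each configuration.
    test_params = dict(params, layer_1_size=layer_1_size,
                       layer_2_size=layer_2_size)
    model = SimpleModel(test_params)
    logits = model(fake_x)
    self.assertEqual(logits.shape,
                     (fake_x.shape[0], test_params['num_classes']))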
Note that it's possible to run session APIs and eager execution in the same test case. The code snippets below show how.
import unittest

class TestNumericalEquivalence(unittest.TestCase):

  # copied from code samples above
  def setUp(self):
    # record statistics for 100 training steps
    step_num = 100

    # setup TF 1 model
    random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
    with random_tool.scope():
      # run TF1.x code in graph mode with context management
      graph = tf.Graph()
      with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
        self.model_tf1 = SimpleModelWrapper()
        # build the model
        inputs = tf.compat.v1.placeholder(tf.float32, shape=(None, params['input_size']))
        labels = tf.compat.v1.placeholder(tf.float32, shape=(None, params['num_classes']))
        spec = self.model_tf1.model_fn(inputs, labels, tf.estimator.ModeKeys.TRAIN, params)
        train_op = spec.train_op

        sess.run(tf.compat.v1.global_variables_initializer())
        for step in range(step_num):
          # log everything and update the model for one step
          logs, _ = sess.run(
              [self.model_tf1.logged_ops, train_op],
              feed_dict={inputs: fake_x, labels: fake_y})
          self.model_tf1.update_logs(logs)

    # setup TF2 model
    random_tool = v1.keras.utils.DeterministicRandomTestTool(mode='num_random_ops')
    with random_tool.scope():
      self.model_tf2 = SimpleModel(params)
      for step in range(step_num):
        self.model_tf2.train_step([fake_x, fake_y])

  def test_learning_rate(self):
    np.testing.assert_allclose(
        self.model_tf1.logs['lr'],
        self.model_tf2.logs['lr'])

  def test_training_loss(self):
    # adopt different tolerance strategies before and after 10 steps
    first_n_step = 10

    # absolute difference is limited below 1e-5
    # set `equal_nan` to be False to detect potential NaN loss issues
    absolute_tolerance = 1e-5
    np.testing.assert_allclose(
        actual=self.model_tf1.logs['loss'][:first_n_step],
        desired=self.model_tf2.logs['loss'][:first_n_step],
        atol=absolute_tolerance,
        equal_nan=False)

    # relative difference is limited below 5%
    relative_tolerance = 0.05
    np.testing.assert_allclose(self.model_tf1.logs['loss'][first_n_step:],
                               self.model_tf2.logs['loss'][first_n_step:],
                               rtol=relative_tolerance,
                               equal_nan=False)
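To execute the test case above from inside a notebook (rather than through a separate test runner), one common pattern is:
# Pass a dummy argv so unittest does not try to parse the notebook's own
# arguments, and keep `exit=False` so the kernel is not shut down.
unittest.main(argv=['first-arg-is-ignored'], exit=False)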
Debugging tools
tf.print
tf.print vs print/logging.info
- With configurable arguments, `tf.print` can recursively display the first and last few elements of each dimension for printed tensors. Check the API docs for details.
- For eager execution, both `print` and `tf.print` print the value of the tensor. But `print` may involve a device-to-host copy, which can potentially slow down your code.
- For graph mode, including usage inside `tf.function`, you need to use `tf.print` to print the actual tensor value. `tf.print` is compiled into an op in the graph, whereas `print` and `logging.info` only log at tracing time, which is often not what you want.
- `tf.print` also supports printing composite tensors like `tf.RaggedTensor` and `tf.sparse.SparseTensor`.
- You can also use a callback to monitor metrics and variables. Please check how to use custom callbacks with the logs dict and the self.model attribute.
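As an illustration of the callback approach from the last bullet, here is a minimal sketch (the callback name is hypothetical) that uses the `logs` dict and the `self.model` attribute:
class DebugCallback(tf.keras.callbacks.Callback):
  def on_train_batch_end(self, batch, logs=None):
    logs = logs or {}
    # `logs` carries batch-level metrics such as the loss.
    tf.print('batch:', batch, 'loss:', logs.get('loss', 0.0))
    # `self.model` gives access to weights, optimizer state, etc.
    tf.print('first kernel norm:', tf.norm(self.model.trainable_weights[0]))

# Usage (assuming a compiled Keras model `model` with data `x`, `y`):
#   model.fit(x, y, callbacks=[DebugCallback()])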
tf.print vs print inside tf.function
# `print` prints info of tensor object
# `tf.print` prints the tensor value
@tf.function
def dummy_func(num):
  num += 1
  print(num)
  tf.print(num)
  return num
_ = dummy_func(tf.constant([1.0]))
# Output:
# Tensor("add:0", shape=(1,), dtype=float32)
# [2]
Tensor("add:0", shape=(1,), dtype=float32) [2]
tf.distribute.Strategy
- If the `tf.function` containing `tf.print` is executed on the workers, for example when using `TPUStrategy` or `ParameterServerStrategy`, you need to check worker/parameter server logs to find the printed values.
- For `print` or `logging.info`, logs will be printed on the coordinator when using `ParameterServerStrategy`, and logs will be printed on the STDOUT of worker0 when using TPUs.
tf.keras.Model
- When using Sequential and Functional API models, if you want to print values, e.g., model inputs or intermediate features after some layers, you have the following options:
  - Write a custom layer that `tf.print`s the inputs.
  - Include the intermediate outputs you want to inspect in the model outputs.
- `tf.keras.layers.Lambda` layers have (de)serialization limitations. To avoid checkpoint loading issues, write a custom subclassed layer instead. Check the API docs for more details.
- You can't `tf.print` intermediate outputs in a `tf.keras.callbacks.LambdaCallback` if you don't have access to the actual values, but instead only to the symbolic Keras tensor objects.
Option 1: write a custom layer
class PrintLayer(tf.keras.layers.Layer):
  def call(self, inputs):
    tf.print(inputs)
    return inputs

def get_model():
  inputs = tf.keras.layers.Input(shape=(1,))
  out_1 = tf.keras.layers.Dense(4)(inputs)
  out_2 = tf.keras.layers.Dense(1)(out_1)
  # use custom layer to tf.print intermediate features
  out_3 = PrintLayer()(out_2)
  model = tf.keras.Model(inputs=inputs, outputs=out_3)
  return model

model = get_model()
model.compile(optimizer="adam", loss="mse")
model.fit([1, 2, 3], [0.0, 0.0, 1.0])
[[-0.327884018] [-0.109294683] [-0.218589365]] 1/1 [==============================] - 0s 402ms/step - loss: 0.6077 <keras.callbacks.History at 0x7f67b584fa90>
Option 2: include the intermediate outputs you want to inspect in the model outputs.
Note that in such a case, you may need some customizations to use `Model.fit` (see the sketch after the model definition below).
def get_model():
  inputs = tf.keras.layers.Input(shape=(1,))
  out_1 = tf.keras.layers.Dense(4)(inputs)
  out_2 = tf.keras.layers.Dense(1)(out_1)
  # include intermediate values in model outputs
  model = tf.keras.Model(
      inputs=inputs,
      outputs={
          'inputs': inputs,
          'out_1': out_1,
          'out_2': out_2})
  return model
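One possible way to adapt `Model.fit` to these dictionary outputs (a sketch, not the only option) is to attach a loss only to the output you actually train on; the remaining entries are returned for inspection but do not contribute to training.
model = get_model()
# Only 'out_2' gets a loss; 'inputs' and 'out_1' are debug outputs.
model.compile(optimizer='adam',
              loss={'out_2': tf.keras.losses.MeanSquaredError()})
# Targets are passed as a dict keyed by output name.
model.fit(np.array([[1.0], [2.0], [3.0]]),
          {'out_2': np.array([[0.0], [0.0], [1.0]])})
# model.predict(...) now returns a dict with keys 'inputs', 'out_1', 'out_2'.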
pdb
You can use pdb both in a terminal and in Colab to inspect intermediate values for debugging.
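For example, a minimal sketch (the helper name is hypothetical) that pauses execution inside an eager forward pass so you can inspect tensors interactively:
import pdb

def inspect_forward_pass(model, x):
  out = model(x)
  # Execution pauses here; inspect `out`, `model.trainable_weights`, etc.,
  # then type `c` to continue.
  pdb.set_trace()
  return out

# Usage: inspect_forward_pass(model_tf2, fake_x)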
Visualize graph with TensorBoard
You can examine the TensorFlow graph with TensorBoard, which is also supported in Colab. TensorBoard is a great tool to visualize summaries. You can use it to compare the learning rate, model weights, gradient scale, training/validation metrics, or even intermediate model outputs between the TF1.x model and the migrated TF2 model throughout the training process, and to see whether the values look as expected.
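For instance, here is a sketch (the log directory is illustrative) that writes the statistics collected earlier into a TF2 run directory so it can be compared against a similar run written for the TF1.x model:
# Write the logged scalars to TensorBoard for side-by-side comparison.
writer = tf.summary.create_file_writer('/tmp/logs/tf2_run')
with writer.as_default():
  for step in range(step_num):
    tf.summary.scalar('loss', model_tf2.logs['loss'][step], step=step)
    tf.summary.scalar('lr', model_tf2.logs['lr'][step], step=step)
writer.flush()
# Launch TensorBoard on /tmp/logs (e.g. `%tensorboard --logdir /tmp/logs` in Colab)
# and compare against a run written the same way from `model_tf1.logs`.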
TensorFlow Profiler
TensorFlow Profiler can help you visualize the execution timeline on GPUs/TPUs. You can check out this Colab Demo for its basic usage.
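For reference, a minimal sketch of the programmatic profiling API (the log directory is illustrative); the resulting trace can be opened in TensorBoard's Profile tab:
# Capture a profile of a few training steps.
tf.profiler.experimental.start('/tmp/profiler_logs')
for step in range(step_num):
  model_tf2.train_step([fake_x, fake_y])
tf.profiler.experimental.stop()
# View with: %tensorboard --logdir /tmp/profiler_logs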
MultiProcessRunner
MultiProcessRunner is a useful tool when debugging with MultiWorkerMirroredStrategy and ParameterServerStrategy. You can take a look at this concrete example for its usage.
Specifically for these two strategies, it is recommended that you 1) have unit tests covering their flow, and 2) attempt to reproduce failures with MultiProcessRunner in a unit test, to avoid launching a real distributed job every time you attempt a fix.