Overview
This guide provides a list of best practices for writing code using TensorFlow 2 (TF2). It is written for users who have recently switched over from TensorFlow 1 (TF1). Refer to the migrate section of the guide for more info on migrating your TF1 code to TF2.
Setup
Import TensorFlow and other dependencies for the examples in this guide.
import tensorflow as tf
import tensorflow_datasets as tfds
Recommendations for idiomatic TensorFlow 2
Refactor your code into smaller modules
A good practice is to refactor your code into smaller functions that are called as needed. For best performance, you should try to decorate the largest blocks of computation that you can in a tf.function (note that the nested Python functions called by a tf.function do not require their own separate decorations, unless you want to use different jit_compile settings for the tf.function). Depending on your use case, this could be multiple training steps or even your whole training loop. For inference use cases, it might be a single model forward pass.
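For example, here is a minimal sketch that wraps one full training step in a single tf.function; the model, optimizer, and loss_fn names are placeholders assumed to exist, and jit_compile is only shown to illustrate where a different compilation setting would go:
@tf.function(jit_compile=False)  # flip to True to try XLA compilation for this block
def train_step(x, y):
  with tf.GradientTape() as tape:
    # Plain Python helpers called in here are traced as part of this
    # tf.function and do not need their own decorator.
    loss = loss_fn(y, model(x, training=True))
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return loss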
Adjust the default learning rate for some tf.keras.optimizers
Some Keras optimizers have different learning rates in TF2. If you see a change in convergence behavior for your models, check the default learning rates.
There are no changes for optimizers.SGD, optimizers.Adam, or optimizers.RMSprop.
The following default learning rates have changed:
- optimizers.Adagrad from 0.01 to 0.001
- optimizers.Adadelta from 1.0 to 0.001
- optimizers.Adamax from 0.002 to 0.001
- optimizers.Nadam from 0.002 to 0.001
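If your model depended on one of the old defaults, you can pass the learning rate explicitly when constructing the optimizer. A minimal sketch restoring the former Adagrad default:
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.01)  # former TF1-era default, instead of 0.001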
Use tf.Modules and Keras layers to manage variables
tf.Modules and tf.keras.layers.Layers offer the convenient variables and trainable_variables properties, which recursively gather up all dependent variables. This makes it easy to manage variables locally to where they are being used.
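As a minimal sketch, a hand-written tf.Module (the Linear class below is illustrative) exposes every tf.Variable that it or its sub-modules create through these properties:
class Linear(tf.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    self.w = tf.Variable(tf.random.normal([in_features, out_features]), name='w')
    self.b = tf.Variable(tf.zeros([out_features]), name='b')

  def __call__(self, x):
    return tf.matmul(x, self.w) + self.b

layer = Linear(3, 2)
# Both properties recursively collect the variables owned by this module.
print([v.name for v in layer.trainable_variables])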
Keras layers/models inherit from tf.train.Checkpointable and are integrated with @tf.function, which makes it possible to directly checkpoint or export SavedModels from Keras objects. You do not necessarily have to use Keras' Model.fit API to take advantage of these integrations.
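For example, a sketch of checkpointing and exporting a small Keras model directly, without Model.fit (the model and the /tmp paths are illustrative):
export_model = tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(8,))])

# Checkpoint the model's variables.
ckpt = tf.train.Checkpoint(model=export_model)
ckpt_path = ckpt.save('/tmp/example_ckpt/ckpt')  # path is arbitrary

# Export the whole model as a SavedModel.
tf.saved_model.save(export_model, '/tmp/example_saved_model')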
Read the section on transfer learning and fine-tuning in the Keras guide to learn how to collect a subset of relevant variables using Keras.
Combine tf.data.Datasets and tf.function
The TensorFlow Datasets package (tfds) contains utilities for loading predefined datasets as tf.data.Dataset objects. For this example, you can load the MNIST dataset using tfds:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Then prepare the data for training:
- Re-scale each image.
- Shuffle the order of the examples.
- Collect batches of images and labels.
BUFFER_SIZE = 10  # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255
  return image, label
To keep the example short, trim the dataset to only return 5 batches:
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
Use regular Python iteration to iterate over training data that fits in memory. Otherwise, tf.data.Dataset is the best way to stream training data from disk. Datasets are iterables (not iterators), and work just like other Python iterables in eager execution. You can fully utilize dataset async prefetching/streaming features by wrapping your code in tf.function, which replaces Python iteration with the equivalent graph operations using AutoGraph.
@tf.function
def train(model, dataset, optimizer):
  for x, y in dataset:
    with tf.GradientTape() as tape:
      # training=True is only needed if there are layers with different
      # behavior during training versus inference (e.g. Dropout).
      prediction = model(x, training=True)
      loss = loss_fn(prediction, y)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
If you use the Keras Model.fit API, you won't have to worry about dataset iteration.
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)
Use Keras training loops
If you don't need low-level control of your training process, using Keras' built-in fit, evaluate, and predict methods is recommended. These methods provide a uniform interface to train the model regardless of the implementation (sequential, functional, or subclassed).
The advantages of these methods include:
- They accept Numpy arrays, Python generators and tf.data.Datasets.
- They apply regularization and activation losses automatically.
- They support tf.distribute, where the training code remains the same regardless of the hardware configuration.
- They support arbitrary callables as losses and metrics.
- They support callbacks like tf.keras.callbacks.TensorBoard, and custom callbacks.
- They are performant, automatically using TensorFlow graphs.
Here is an example of training a model using a Dataset. For details on how this works, check out the tutorials.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])

# Model is the full model w/o custom layers
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 2s 44ms/step - loss: 1.6644 - accuracy: 0.4906
Epoch 2/5
5/5 [==============================] - 0s 9ms/step - loss: 0.5173 - accuracy: 0.9062
Epoch 3/5
5/5 [==============================] - 0s 9ms/step - loss: 0.3418 - accuracy: 0.9469
Epoch 4/5
5/5 [==============================] - 0s 8ms/step - loss: 0.2707 - accuracy: 0.9781
Epoch 5/5
5/5 [==============================] - 0s 8ms/step - loss: 0.2195 - accuracy: 0.9812
5/5 [==============================] - 0s 4ms/step - loss: 1.6036 - accuracy: 0.6250
Loss 1.6036441326141357, Accuracy 0.625
Customize training and write your own loop
If Keras models work for you, but you need more flexibility and control of the training step or the outer training loops, you can implement your own training steps or even entire training loops. See the Keras guide on customizing fit to learn more.
You can also implement many things as a tf.keras.callbacks.Callback.
This method has many of the advantages mentioned previously, but gives you control of the train step and even the outer loop.
There are three steps to a standard training loop:
- Iterate over a Python generator or tf.data.Dataset to get batches of examples.
- Use tf.GradientTape to collect gradients.
- Use one of the tf.keras.optimizers to apply weight updates to the model's variables.
Remember:
- Always include a training argument on the call method of subclassed layers and models (see the sketch after this list).
- Make sure to call the model with the training argument set correctly.
- Depending on usage, model variables may not exist until the model is run on a batch of data.
- You need to manually handle things like regularization losses for the model.
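Here is a minimal sketch of the first two points: a subclassed layer (illustrative, not part of the example model below) that accepts a training argument and forwards it to layers whose behavior differs between training and inference:
class DenseWithDropout(tf.keras.layers.Layer):
  def __init__(self, units):
    super().__init__()
    self.dense = tf.keras.layers.Dense(units, activation='relu')
    self.dropout = tf.keras.layers.Dropout(0.5)

  def call(self, inputs, training=False):
    x = self.dense(inputs)
    # Forward `training` so Dropout is only active during training.
    return self.dropout(x, training=training)

layer = DenseWithDropout(8)
outputs = layer(tf.random.normal([2, 4]), training=True)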
There is no need to run variable initializers or to add manual control dependencies. tf.function handles automatic control dependencies and variable initialization on creation for you.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, 3, activation='relu',
                           kernel_regularizer=tf.keras.regularizers.l2(0.02),
                           input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dropout(0.1),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss
  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
for epoch in range(NUM_EPOCHS):
  for inputs, labels in train_data:
    train_step(inputs, labels)
  print("Finished epoch", epoch)
Finished epoch 0
Finished epoch 1
Finished epoch 2
Finished epoch 3
Finished epoch 4
Take advantage of tf.function with Python control flow
tf.function provides a way to convert data-dependent control flow into graph-mode equivalents like tf.cond and tf.while_loop.
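As a small illustrative sketch, AutoGraph turns ordinary Python control flow that depends on Tensor values into these graph operations for you (the count_halvings function is hypothetical):
@tf.function
def count_halvings(x):
  # The Python `while` below becomes a tf.while_loop because `x` is a Tensor.
  steps = tf.constant(0)
  while x > 1.0:
    x = x / 2.0
    steps += 1
  return steps

print(count_halvings(tf.constant(10.0)).numpy())  # 4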
One common place where data-dependent control flow appears is in sequence models. tf.keras.layers.RNN wraps an RNN cell, allowing you to either statically or dynamically unroll the recurrence. As an example, you could reimplement dynamic unroll as follows.
class DynamicRNN(tf.keras.Model):

  def __init__(self, rnn_cell):
    super(DynamicRNN, self).__init__(self)
    self.cell = rnn_cell

  @tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
  def call(self, input_data):
    # [batch, time, features] -> [time, batch, features]
    input_data = tf.transpose(input_data, [1, 0, 2])
    timesteps = tf.shape(input_data)[0]
    batch_size = tf.shape(input_data)[1]
    outputs = tf.TensorArray(tf.float32, timesteps)
    state = self.cell.get_initial_state(batch_size=batch_size, dtype=tf.float32)
    for i in tf.range(timesteps):
      output, state = self.cell(input_data[i], state)
      outputs = outputs.write(i, output)
    return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)
my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)
Read the tf.function guide for more information.
New-style metrics and losses
Metrics and losses are both objects that work eagerly and in tf.functions.
A loss object is callable, and expects (y_true, y_pred) as arguments:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815
Use metrics to collect and display data
You can use tf.metrics to aggregate data and tf.summary to log summaries, redirecting them to a writer using a context manager. The summaries are emitted directly to the writer, which means that you must provide the step value at the callsite.
summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
  tf.summary.scalar('loss', 0.1, step=42)
Use tf.metrics to aggregate data before logging them as summaries. Metrics are stateful; they accumulate values and return a cumulative result when you call the result method (such as Mean.result). Clear accumulated values with the metric's reset_states method (such as Mean.reset_states).
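A short sketch of that statefulness, independent of any model:
mean = tf.keras.metrics.Mean()
mean.update_state(2.0)
mean.update_state(4.0)
print(mean.result().numpy())  # 3.0, the running average of all updates so far
mean.reset_states()           # start accumulating from scratch
print(mean.result().numpy())  # 0.0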
def train(model, optimizer, dataset, log_freq=10):
  avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
  for images, labels in dataset:
    loss = train_step(model, optimizer, images, labels)
    avg_loss.update_state(loss)
    if tf.equal(optimizer.iterations % log_freq, 0):
      tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
      avg_loss.reset_states()

def test(model, test_x, test_y, step_num):
  # training=False is only needed if there are layers with different
  # behavior during training versus inference (e.g. Dropout).
  loss = loss_fn(model(test_x, training=False), test_y)
  tf.summary.scalar('loss', loss, step=step_num)
train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')
with train_summary_writer.as_default():
  train(model, optimizer, dataset)

with test_summary_writer.as_default():
  test(model, test_x, test_y, optimizer.iterations)
Visualize the generated summaries by pointing TensorBoard to the summary log directory:
tensorboard --logdir /tmp/summaries
Use the tf.summary API to write summary data for visualization in TensorBoard. For more info, read the tf.summary guide.
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
  with tf.GradientTape() as tape:
    predictions = model(inputs, training=True)
    regularization_loss = tf.math.add_n(model.losses)
    pred_loss = loss_fn(labels, predictions)
    total_loss = pred_loss + regularization_loss
  gradients = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  # Update the metrics
  loss_metric.update_state(total_loss)
  accuracy_metric.update_state(labels, predictions)
for epoch in range(NUM_EPOCHS):
  # Reset the metrics
  loss_metric.reset_states()
  accuracy_metric.reset_states()

  for inputs, labels in train_data:
    train_step(inputs, labels)

  # Get the metric results
  mean_loss = loss_metric.result()
  mean_accuracy = accuracy_metric.result()

  print('Epoch: ', epoch)
  print('  loss:     {:.3f}'.format(mean_loss))
  print('  accuracy: {:.3f}'.format(mean_accuracy))
Epoch:  0
  loss:     0.176
  accuracy: 0.994
Epoch:  1
  loss:     0.153
  accuracy: 0.991
Epoch:  2
  loss:     0.134
  accuracy: 0.994
Epoch:  3
  loss:     0.108
  accuracy: 1.000
Epoch:  4
  loss:     0.095
  accuracy: 1.000
Keras metric names
Keras models are consistent about handling metric names. When you pass a string in the list of metrics, that exact string is used as the metric's name. These names are visible in the history object returned by model.fit, and in the logs passed to keras.callbacks.
model.compile(
optimizer = tf.keras.optimizers.Adam(0.001),
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 9ms/step - loss: 0.1077 - acc: 0.9937 - accuracy: 0.9937 - my_accuracy: 0.9937
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])
Debugging
Use eager execution to run your code step-by-step to inspect shapes, data types, and values. Certain APIs, like tf.function and tf.keras, are designed to use graph execution for performance and portability. When debugging, use tf.config.run_functions_eagerly(True) to use eager execution inside this code.
For example:
@tf.function
def f(x):
  if x > 0:
    import pdb
    pdb.set_trace()
    x = x + 1
  return x
tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
6 @tf.function
7 def f(x):
8 if x > 0:
9 import pdb
10 pdb.set_trace()
11 -> x = x + 1
12 return x
13
14 tf.config.run_functions_eagerly(True)
15 f(tf.constant(1))
[EOF]
This also works inside Keras models and other APIs that support eager execution:
class CustomModel(tf.keras.models.Model):

  @tf.function
  def call(self, input_data):
    if tf.reduce_mean(input_data) > 0:
      return input_data
    else:
      import pdb
      pdb.set_trace()
      return input_data // 2
tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
10 if tf.reduce_mean(input_data) > 0:
11 return input_data
12 else:
13 import pdb
14 pdb.set_trace()
15 -> return input_data // 2
16
17
18 tf.config.run_functions_eagerly(True)
19 model = CustomModel()
20 model(tf.constant([-2, -4]))
Notes:
- tf.keras.Model methods such as fit, evaluate, and predict execute as graphs with tf.function under the hood.
- When using tf.keras.Model.compile, set run_eagerly = True to disable the Model logic from being wrapped in a tf.function (see the sketch after these notes).
- Use tf.data.experimental.enable_debug_mode to enable the debug mode for tf.data. Read the API docs for more details.
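A quick sketch of those two switches; the optimizer and loss passed to compile are placeholders:
# Run the Keras built-in loops eagerly instead of inside tf.function.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              run_eagerly=True)

# Opt tf.data pipelines into debug mode (typically combined with
# tf.config.run_functions_eagerly(True), as described in the API docs).
tf.data.experimental.enable_debug_mode()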
Do not keep tf.Tensors in your objects
These tensor objects might get created either in a tf.function or in the eager context, and these tensors behave differently. Always use tf.Tensors only for intermediate values.
To track state, use tf.Variables as they are always usable from both contexts. Read the tf.Variable guide to learn more.
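For example, a minimal sketch (the Counter class is illustrative) of keeping state in a tf.Variable rather than stashing a tf.Tensor on the object:
class Counter(tf.Module):
  def __init__(self):
    super().__init__()
    # A tf.Variable behaves the same way in eager code and inside tf.function.
    self.count = tf.Variable(0)

  @tf.function
  def increment(self):
    self.count.assign_add(1)
    return self.count

counter = Counter()
print(int(counter.increment()))  # 1
print(int(counter.increment()))  # 2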
Resources and further reading
Read the TF2 guides and tutorials to learn more about how to use TF2.
If you previously used TF1.x, it is highly recommended you migrate your code to TF2. Read the migration guides to learn more.