View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook |
This tutorial contains a minimal implementation of DeepDream, as described in this blog post by Alexander Mordvintsev.
DeepDream is an experiment that visualizes the patterns learned by a neural network. Similar to when a child watches clouds and tries to interpret random shapes, DeepDream over-interprets and enhances the patterns it sees in an image.
It does so by forwarding an image through the network, then calculating the gradient of the image with respect to the activations of a particular layer. The image is then modified to increase these activations, enhancing the patterns seen by the network, and resulting in a dream-like image. This process was dubbed "Inceptionism" (a reference to InceptionNet, and the movie Inception).
Let's demonstrate how you can make a neural network "dream" and enhance the surreal patterns it sees in an image.
import tensorflow as tf
2024-08-16 03:30:23.118686: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-08-16 03:30:23.140085: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-08-16 03:30:23.146564: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
import numpy as np
import matplotlib as mpl
import IPython.display as display
import PIL.Image
Choose an image to dream-ify
For this tutorial, let's use an image of a labrador.
url = 'https://storage.googleapis.com/download.tensorflow.org/example_images/YellowLabradorLooking_new.jpg'
# Download an image and read it into a NumPy array.
def download(url, max_dim=None):
name = url.split('/')[-1]
image_path = tf.keras.utils.get_file(name, origin=url)
img = PIL.Image.open(image_path)
if max_dim:
img.thumbnail((max_dim, max_dim))
return np.array(img)
# Normalize an image
def deprocess(img):
img = 255*(img + 1.0)/2.0
return tf.cast(img, tf.uint8)
# Display an image
def show(img):
display.display(PIL.Image.fromarray(np.array(img)))
# Downsizing the image makes it easier to work with.
original_img = download(url, max_dim=500)
show(original_img)
display.display(display.HTML('Image cc-by: <a "href=https://commons.wikimedia.org/wiki/File:Felis_catus-cat_on_snow.jpg">Von.grzanka</a>'))
Prepare the feature extraction model
Download and prepare a pre-trained image classification model. You will use InceptionV3 which is similar to the model originally used in DeepDream. Note that any pre-trained model will work, although you will have to adjust the layer names below if you change this.
base_model = tf.keras.applications.InceptionV3(include_top=False, weights='imagenet')
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1723779025.863071 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779025.866858 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779025.870100 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779025.873272 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779025.884930 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779025.888405 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779025.891380 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779025.894290 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779025.897661 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779025.901177 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779025.904095 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779025.907002 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.131871 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.133971 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.136078 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.138148 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.140193 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.142169 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.144143 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.146665 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.148566 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.150549 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.152512 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.154476 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.193123 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.195202 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.197345 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.199477 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.201454 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.203434 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.205405 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.207394 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.209335 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.211774 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.214221 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 I0000 00:00:1723779027.216600 149110 cuda_executor.cc:1015] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355 Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5 87910968/87910968 ━━━━━━━━━━━━━━━━━━━━ 1s 0us/step
The idea in DeepDream is to choose a layer (or layers) and maximize the "loss" in a way that the image increasingly "excites" the layers. The complexity of the features incorporated depends on layers chosen by you, i.e, lower layers produce strokes or simple patterns, while deeper layers give sophisticated features in images, or even whole objects.
The InceptionV3 architecture is quite large (for a graph of the model architecture see TensorFlow's research repo). For DeepDream, the layers of interest are those where the convolutions are concatenated. There are 11 of these layers in InceptionV3, named 'mixed0' though 'mixed10'. Using different layers will result in different dream-like images. Deeper layers respond to higher-level features (such as eyes and faces), while earlier layers respond to simpler features (such as edges, shapes, and textures). Feel free to experiment with the layers selected below, but keep in mind that deeper layers (those with a higher index) will take longer to train on since the gradient computation is deeper.
# Maximize the activations of these layers
names = ['mixed3', 'mixed5']
layers = [base_model.get_layer(name).output for name in names]
# Create the feature extraction model
dream_model = tf.keras.Model(inputs=base_model.input, outputs=layers)
Calculate loss
The loss is the sum of the activations in the chosen layers. The loss is normalized at each layer so the contribution from larger layers does not outweigh smaller layers. Normally, loss is a quantity you wish to minimize via gradient descent. In DeepDream, you will maximize this loss via gradient ascent.
def calc_loss(img, model):
# Pass forward the image through the model to retrieve the activations.
# Converts the image into a batch of size 1.
img_batch = tf.expand_dims(img, axis=0)
layer_activations = model(img_batch)
if len(layer_activations) == 1:
layer_activations = [layer_activations]
losses = []
for act in layer_activations:
loss = tf.math.reduce_mean(act)
losses.append(loss)
return tf.reduce_sum(losses)
Gradient ascent
Once you have calculated the loss for the chosen layers, all that is left is to calculate the gradients with respect to the image, and add them to the original image.
Adding the gradients to the image enhances the patterns seen by the network. At each step, you will have created an image that increasingly excites the activations of certain layers in the network.
The method that does this, below, is wrapped in a tf.function
for performance. It uses an input_signature
to ensure that the function is not retraced for different image sizes or steps
/step_size
values. See the Concrete functions guide for details.
class DeepDream(tf.Module):
def __init__(self, model):
self.model = model
@tf.function(
input_signature=(
tf.TensorSpec(shape=[None,None,3], dtype=tf.float32),
tf.TensorSpec(shape=[], dtype=tf.int32),
tf.TensorSpec(shape=[], dtype=tf.float32),)
)
def __call__(self, img, steps, step_size):
print("Tracing")
loss = tf.constant(0.0)
for n in tf.range(steps):
with tf.GradientTape() as tape:
# This needs gradients relative to `img`
# `GradientTape` only watches `tf.Variable`s by default
tape.watch(img)
loss = calc_loss(img, self.model)
# Calculate the gradient of the loss with respect to the pixels of the input image.
gradients = tape.gradient(loss, img)
# Normalize the gradients.
gradients /= tf.math.reduce_std(gradients) + 1e-8
# In gradient ascent, the "loss" is maximized so that the input image increasingly "excites" the layers.
# You can update the image by directly adding the gradients (because they're the same shape!)
img = img + gradients*step_size
img = tf.clip_by_value(img, -1, 1)
return loss, img
deepdream = DeepDream(dream_model)
Main Loop
def run_deep_dream_simple(img, steps=100, step_size=0.01):
# Convert from uint8 to the range expected by the model.
img = tf.keras.applications.inception_v3.preprocess_input(img)
img = tf.convert_to_tensor(img)
step_size = tf.convert_to_tensor(step_size)
steps_remaining = steps
step = 0
while steps_remaining:
if steps_remaining>100:
run_steps = tf.constant(100)
else:
run_steps = tf.constant(steps_remaining)
steps_remaining -= run_steps
step += run_steps
loss, img = deepdream(img, run_steps, tf.constant(step_size))
display.clear_output(wait=True)
show(deprocess(img))
print ("Step {}, loss {}".format(step, loss))
result = deprocess(img)
display.clear_output(wait=True)
show(result)
return result
dream_img = run_deep_dream_simple(img=original_img,
steps=100, step_size=0.01)
Taking it up an octave
Pretty good, but there are a few issues with this first attempt:
- The output is noisy (this could be addressed with a
tf.image.total_variation
loss). - The image is low resolution.
- The patterns appear like they're all happening at the same granularity.
One approach that addresses all these problems is applying gradient ascent at different scales. This will allow patterns generated at smaller scales to be incorporated into patterns at higher scales and filled in with additional detail.
To do this you can perform the previous gradient ascent approach, then increase the size of the image (which is referred to as an octave), and repeat this process for multiple octaves.
import time
start = time.time()
OCTAVE_SCALE = 1.30
img = tf.constant(np.array(original_img))
base_shape = tf.shape(img)[:-1]
float_base_shape = tf.cast(base_shape, tf.float32)
for n in range(-2, 3):
new_shape = tf.cast(float_base_shape*(OCTAVE_SCALE**n), tf.int32)
img = tf.image.resize(img, new_shape).numpy()
img = run_deep_dream_simple(img=img, steps=50, step_size=0.01)
display.clear_output(wait=True)
img = tf.image.resize(img, base_shape)
img = tf.image.convert_image_dtype(img/255.0, dtype=tf.uint8)
show(img)
end = time.time()
end-start
11.953380346298218
Optional: Scaling up with tiles
One thing to consider is that as the image increases in size, so will the time and memory necessary to perform the gradient calculation. The above octave implementation will not work on very large images, or many octaves.
To avoid this issue you can split the image into tiles and compute the gradient for each tile.
Applying random shifts to the image before each tiled computation prevents tile seams from appearing.
Start by implementing the random shift:
def random_roll(img, maxroll):
# Randomly shift the image to avoid tiled boundaries.
shift = tf.random.uniform(shape=[2], minval=-maxroll, maxval=maxroll, dtype=tf.int32)
img_rolled = tf.roll(img, shift=shift, axis=[0,1])
return shift, img_rolled
shift, img_rolled = random_roll(np.array(original_img), 512)
show(img_rolled)
Here is a tiled equivalent of the deepdream
function defined earlier:
class TiledGradients(tf.Module):
def __init__(self, model):
self.model = model
@tf.function(
input_signature=(
tf.TensorSpec(shape=[None,None,3], dtype=tf.float32),
tf.TensorSpec(shape=[2], dtype=tf.int32),
tf.TensorSpec(shape=[], dtype=tf.int32),)
)
def __call__(self, img, img_size, tile_size=512):
shift, img_rolled = random_roll(img, tile_size)
# Initialize the image gradients to zero.
gradients = tf.zeros_like(img_rolled)
# Skip the last tile, unless there's only one tile.
xs = tf.range(0, img_size[1], tile_size)[:-1]
if not tf.cast(len(xs), bool):
xs = tf.constant([0])
ys = tf.range(0, img_size[0], tile_size)[:-1]
if not tf.cast(len(ys), bool):
ys = tf.constant([0])
for x in xs:
for y in ys:
# Calculate the gradients for this tile.
with tf.GradientTape() as tape:
# This needs gradients relative to `img_rolled`.
# `GradientTape` only watches `tf.Variable`s by default.
tape.watch(img_rolled)
# Extract a tile out of the image.
img_tile = img_rolled[y:y+tile_size, x:x+tile_size]
loss = calc_loss(img_tile, self.model)
# Update the image gradients for this tile.
gradients = gradients + tape.gradient(loss, img_rolled)
# Undo the random shift applied to the image and its gradients.
gradients = tf.roll(gradients, shift=-shift, axis=[0,1])
# Normalize the gradients.
gradients /= tf.math.reduce_std(gradients) + 1e-8
return gradients
get_tiled_gradients = TiledGradients(dream_model)
Putting this together gives a scalable, octave-aware deepdream implementation:
def run_deep_dream_with_octaves(img, steps_per_octave=100, step_size=0.01,
octaves=range(-2,3), octave_scale=1.3):
base_shape = tf.shape(img)
img = tf.keras.utils.img_to_array(img)
img = tf.keras.applications.inception_v3.preprocess_input(img)
initial_shape = img.shape[:-1]
img = tf.image.resize(img, initial_shape)
for octave in octaves:
# Scale the image based on the octave
new_size = tf.cast(tf.convert_to_tensor(base_shape[:-1]), tf.float32)*(octave_scale**octave)
new_size = tf.cast(new_size, tf.int32)
img = tf.image.resize(img, new_size)
for step in range(steps_per_octave):
gradients = get_tiled_gradients(img, new_size)
img = img + gradients*step_size
img = tf.clip_by_value(img, -1, 1)
if step % 10 == 0:
display.clear_output(wait=True)
show(deprocess(img))
print ("Octave {}, Step {}".format(octave, step))
result = deprocess(img)
return result
img = run_deep_dream_with_octaves(img=original_img, step_size=0.01)
display.clear_output(wait=True)
img = tf.image.resize(img, base_shape)
img = tf.image.convert_image_dtype(img/255.0, dtype=tf.uint8)
show(img)
Much better! Play around with the number of octaves, octave scale, and activated layers to change how your DeepDream-ed image looks.
Readers might also be interested in TensorFlow Lucid which expands on ideas introduced in this tutorial to visualize and interpret neural networks.