Rete contraddittoria generativa profonda convoluzionale

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza l'origine su GitHub Scarica quaderno

Questo tutorial mostra come generare immagini di cifre scritte a mano utilizzando una Deep Convolutional Generative Adversarial Network (DCGAN). Il codice viene scritto utilizzando l' API Keras Sequential con un ciclo di addestramento tf.GradientTape .

Cosa sono i GAN?

Le reti generative contraddittorio (GAN) sono una delle idee più interessanti dell'informatica odierna. Due modelli sono addestrati simultaneamente da un processo contraddittorio. Un generatore ("l'artista") impara a creare immagini che sembrano reali, mentre un discriminatore ("il critico d'arte") impara a distinguere le immagini reali dai falsi.

Un diagramma di un generatore e discriminatore

Durante l'allenamento, il generatore diventa progressivamente più bravo a creare immagini che sembrano reali, mentre il discriminatore diventa più bravo a distinguerle. Il processo raggiunge l'equilibrio quando il discriminatore non riesce più a distinguere le immagini reali da quelle false.

Un secondo diagramma di un generatore e discriminatore

Questo quaderno mostra questo processo sul set di dati MNIST. L'animazione seguente mostra una serie di immagini prodotte dal generatore mentre veniva addestrato per 50 epoche. Le immagini iniziano come rumore casuale e nel tempo assomigliano sempre più a cifre scritte a mano.

uscita del campione

Per ulteriori informazioni sui GAN, vedere il corso Intro to Deep Learning del MIT.

Impostare

import tensorflow as tf
tf.__version__
'2.8.0-rc1'
# To generate GIFs
pip install imageio
pip install git+https://github.com/tensorflow/docs
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

Carica e prepara il set di dati

Utilizzerai il set di dati MNIST per addestrare il generatore e il discriminatore. Il generatore genererà cifre scritte a mano simili ai dati MNIST.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Crea i modelli

Sia il generatore che il discriminatore sono definiti utilizzando l' API Keras Sequential .

Il generatore

Il generatore utilizza tf.keras.layers.Conv2DTranspose (upsampling) per produrre un'immagine da un seme (rumore casuale). Inizia con un livello Dense che prende questo seme come input, quindi sovracampiona più volte fino a raggiungere la dimensione dell'immagine desiderata di 28x28x1. Notare l'attivazione di tf.keras.layers.LeakyReLU per ogni livello, ad eccezione del livello di output che usa tanh.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

Usa il generatore (non ancora addestrato) per creare un'immagine.

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f6fe7a04b90>

png

Il discriminatore

Il discriminatore è un classificatore di immagini basato sulla CNN.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

Usa il discriminatore (non ancora addestrato) per classificare le immagini generate come reali o false. Il modello verrà addestrato per produrre valori positivi per immagini reali e valori negativi per immagini false.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[-0.00339105]], shape=(1, 1), dtype=float32)

Definire la perdita e gli ottimizzatori

Definisci le funzioni di perdita e gli ottimizzatori per entrambi i modelli.

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Perdita discriminante

Questo metodo quantifica quanto bene il discriminatore è in grado di distinguere le immagini reali da quelle false. Confronta le previsioni del discriminatore su immagini reali con un array di 1 e le previsioni del discriminatore su immagini false (generate) con un array di 0.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

Perdita del generatore

La perdita del generatore quantifica quanto bene è stato in grado di ingannare il discriminatore. Intuitivamente, se il generatore funziona bene, il discriminatore classificherà le immagini false come reali (o 1). Qui, confronta le decisioni dei discriminatori sulle immagini generate con una matrice di 1s.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

Il discriminatore e gli ottimizzatori del generatore sono diversi poiché addestrerai due reti separatamente.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Salva i checkpoint

Questo notebook mostra anche come salvare e ripristinare i modelli, che possono essere utili nel caso in cui un'attività di formazione di lunga durata venga interrotta.

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

Definisci il ciclo di allenamento

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# You will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

Il ciclo di addestramento inizia con il generatore che riceve un seme casuale come input. Quel seme viene utilizzato per produrre un'immagine. Il discriminatore viene quindi utilizzato per classificare le immagini reali (disegnate dal training set) e le immagini false (prodotte dal generatore). La perdita viene calcolata per ciascuno di questi modelli e i gradienti vengono utilizzati per aggiornare il generatore e il discriminatore.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as you go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

Genera e salva immagini

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

Allena il modello

Chiama il metodo train() definito sopra per addestrare il generatore e il discriminatore contemporaneamente. Nota, l'addestramento dei GAN può essere complicato. È importante che il generatore e il discriminatore non si prevalgano a vicenda (ad esempio, che si allenino a una velocità simile).

All'inizio del training, le immagini generate appaiono come rumore casuale. Con il progredire della formazione, le cifre generate appariranno sempre più reali. Dopo circa 50 epoche, assomigliano alle cifre MNIST. Questa operazione potrebbe richiedere circa un minuto/epoca con le impostazioni predefinite su Colab.

train(train_dataset, EPOCHS)

png

Ripristina l'ultimo checkpoint.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f6ee8136950>

Crea una GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

Usa imageio per creare una gif animata usando le immagini salvate durante l'allenamento.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

Prossimi passi

Questo tutorial ha mostrato il codice completo necessario per scrivere e addestrare un GAN. Come passaggio successivo, potresti voler sperimentare un set di dati diverso, ad esempio il set di dati Celeb Faces Attributes (CelebA) su larga scala disponibile su Kaggle . Per saperne di più sui GAN, vedere il Tutorial NIPS 2016: Generative Adversarial Networks .