pix2pix: traduzione da immagine a immagine con un GAN condizionale

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza l'origine su GitHub Scarica quaderno

Questo tutorial mostra come costruire e addestrare una rete contraddittoria generativa condizionale (cGAN) chiamata pix2pix che apprende una mappatura da immagini di input a immagini di output, come descritto in Traduzione da immagine a immagine con reti contraddittorio condizionale di Isola et al. (2017). pix2pix non è specifico dell'applicazione: può essere applicato a un'ampia gamma di attività, inclusa la sintesi di foto da mappe di etichette, la generazione di foto colorate da immagini in bianco e nero, la trasformazione di foto di Google Maps in immagini aeree e persino la trasformazione di schizzi in foto.

In questo esempio, la tua rete genererà immagini di facciate di edifici utilizzando il database delle facciate CMP fornito dal Center for Machine Perception dell'Università tecnica ceca di Praga . Per farla breve, utilizzerai una copia preelaborata di questo set di dati creato dagli autori di pix2pix.

In pix2pix cGAN, si condizionano le immagini di input e si generano le immagini di output corrispondenti. I cGAN sono stati proposti per la prima volta in Conditional Generative Adversarial Nets (Mirza e Osindero, 2014)

L'architettura della tua rete conterrà:

  • Un generatore con architettura basata su U-Net .
  • Un discriminatore rappresentato da un classificatore PatchGAN convoluzionale (proposto nel paper pix2pix ).

Nota che ogni epoca può richiedere circa 15 secondi su una singola GPU V100.

Di seguito sono riportati alcuni esempi dell'output generato da pix2pix cGAN dopo l'allenamento per 200 epoche sul dataset delle facciate (80k passi).

output di esempio_1output di esempio_2

Importa TensorFlow e altre librerie

import tensorflow as tf

import os
import pathlib
import time
import datetime

from matplotlib import pyplot as plt
from IPython import display

Carica il set di dati

Scarica i dati del database delle facciate CMP (30 MB). Ulteriori set di dati sono disponibili nello stesso formato qui . In Colab puoi selezionare altri set di dati dal menu a discesa. Si noti che alcuni degli altri set di dati sono significativamente più grandi ( edges2handbags è 8 GB).

dataset_name = "facades"
_URL = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{dataset_name}.tar.gz'

path_to_zip = tf.keras.utils.get_file(
    fname=f"{dataset_name}.tar.gz",
    origin=_URL,
    extract=True)

path_to_zip  = pathlib.Path(path_to_zip)

PATH = path_to_zip.parent/dataset_name
Downloading data from http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/facades.tar.gz
30171136/30168306 [==============================] - 19s 1us/step
30179328/30168306 [==============================] - 19s 1us/step
list(PATH.parent.iterdir())
[PosixPath('/home/kbuilder/.keras/datasets/facades.tar.gz'),
 PosixPath('/home/kbuilder/.keras/datasets/YellowLabradorLooking_new.jpg'),
 PosixPath('/home/kbuilder/.keras/datasets/facades'),
 PosixPath('/home/kbuilder/.keras/datasets/mnist.npz')]

Ogni immagine originale ha dimensioni 256 x 512 e contiene due immagini 256 x 256 :

sample_image = tf.io.read_file(str(PATH / 'train/1.jpg'))
sample_image = tf.io.decode_jpeg(sample_image)
print(sample_image.shape)
(256, 512, 3)
plt.figure()
plt.imshow(sample_image)
<matplotlib.image.AxesImage at 0x7f35a3653c90>

png

È necessario separare le immagini della facciata dell'edificio reale dalle immagini dell'etichetta dell'architettura, tutte di dimensioni 256 x 256 .

Definire una funzione che carichi i file di immagine e produca due tensori di immagine:

def load(image_file):
  # Read and decode an image file to a uint8 tensor
  image = tf.io.read_file(image_file)
  image = tf.io.decode_jpeg(image)

  # Split each image tensor into two tensors:
  # - one with a real building facade image
  # - one with an architecture label image 
  w = tf.shape(image)[1]
  w = w // 2
  input_image = image[:, w:, :]
  real_image = image[:, :w, :]

  # Convert both images to float32 tensors
  input_image = tf.cast(input_image, tf.float32)
  real_image = tf.cast(real_image, tf.float32)

  return input_image, real_image

Traccia un campione delle immagini di input (immagine dell'etichetta dell'architettura) e reali (foto della facciata dell'edificio):

inp, re = load(str(PATH / 'train/100.jpg'))
# Casting to int for matplotlib to display the images
plt.figure()
plt.imshow(inp / 255.0)
plt.figure()
plt.imshow(re / 255.0)
<matplotlib.image.AxesImage at 0x7f35981a4910>

png

png

Come descritto nel documento pix2pix , è necessario applicare jittering e mirroring casuali per preelaborare il training set.

Definire diverse funzioni che:

  1. Ridimensiona ciascuna immagine da 256 x 256 a un'altezza e una larghezza maggiori: 286 x 286 .
  2. Ritaglialo casualmente a 256 x 256 .
  3. Capovolgere casualmente l'immagine orizzontalmente, cioè da sinistra a destra (rispecchiamento casuale).
  4. Normalizza le immagini nell'intervallo [-1, 1] .
# The facade training set consist of 400 images
BUFFER_SIZE = 400
# The batch size of 1 produced better results for the U-Net in the original pix2pix experiment
BATCH_SIZE = 1
# Each image is 256x256 in size
IMG_WIDTH = 256
IMG_HEIGHT = 256
def resize(input_image, real_image, height, width):
  input_image = tf.image.resize(input_image, [height, width],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
  real_image = tf.image.resize(real_image, [height, width],
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  return input_image, real_image
def random_crop(input_image, real_image):
  stacked_image = tf.stack([input_image, real_image], axis=0)
  cropped_image = tf.image.random_crop(
      stacked_image, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image[0], cropped_image[1]
# Normalizing the images to [-1, 1]
def normalize(input_image, real_image):
  input_image = (input_image / 127.5) - 1
  real_image = (real_image / 127.5) - 1

  return input_image, real_image
@tf.function()
def random_jitter(input_image, real_image):
  # Resizing to 286x286
  input_image, real_image = resize(input_image, real_image, 286, 286)

  # Random cropping back to 256x256
  input_image, real_image = random_crop(input_image, real_image)

  if tf.random.uniform(()) > 0.5:
    # Random mirroring
    input_image = tf.image.flip_left_right(input_image)
    real_image = tf.image.flip_left_right(real_image)

  return input_image, real_image

È possibile ispezionare parte dell'output preelaborato:

plt.figure(figsize=(6, 6))
for i in range(4):
  rj_inp, rj_re = random_jitter(inp, re)
  plt.subplot(2, 2, i + 1)
  plt.imshow(rj_inp / 255.0)
  plt.axis('off')
plt.show()

png

Dopo aver verificato che il caricamento e la preelaborazione funzionino, definiamo un paio di funzioni di supporto che caricano e preelaborano i set di addestramento e test:

def load_image_train(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = random_jitter(input_image, real_image)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image
def load_image_test(image_file):
  input_image, real_image = load(image_file)
  input_image, real_image = resize(input_image, real_image,
                                   IMG_HEIGHT, IMG_WIDTH)
  input_image, real_image = normalize(input_image, real_image)

  return input_image, real_image

Crea una pipeline di input con tf.data

train_dataset = tf.data.Dataset.list_files(str(PATH / 'train/*.jpg'))
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)
try:
  test_dataset = tf.data.Dataset.list_files(str(PATH / 'test/*.jpg'))
except tf.errors.InvalidArgumentError:
  test_dataset = tf.data.Dataset.list_files(str(PATH / 'val/*.jpg'))
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(BATCH_SIZE)

Costruisci il generatore

Il generatore del tuo pix2pix cGAN è un U-Net modificato . Una U-Net è composta da un encoder (downsampler) e un decoder (upsampler). (Puoi saperne di più nel tutorial sulla segmentazione delle immagini e sul sito Web del progetto U-Net .)

  • Ogni blocco nell'encoder è: Convoluzione -> Normalizzazione batch -> Leaky ReLU
  • Ogni blocco nel decoder è: Convoluzione trasposta -> Normalizzazione batch -> Dropout (applicato ai primi 3 blocchi) -> ReLU
  • Ci sono connessioni salta tra l'encoder e il decoder (come in U-Net).

Definire il downsampler (encoder):

OUTPUT_CHANNELS = 3
def downsample(filters, size, apply_batchnorm=True):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
      tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                             kernel_initializer=initializer, use_bias=False))

  if apply_batchnorm:
    result.add(tf.keras.layers.BatchNormalization())

  result.add(tf.keras.layers.LeakyReLU())

  return result
down_model = downsample(3, 4)
down_result = down_model(tf.expand_dims(inp, 0))
print (down_result.shape)
(1, 128, 128, 3)

Definire l'upsampler (decodificatore):

def upsample(filters, size, apply_dropout=False):
  initializer = tf.random_normal_initializer(0., 0.02)

  result = tf.keras.Sequential()
  result.add(
    tf.keras.layers.Conv2DTranspose(filters, size, strides=2,
                                    padding='same',
                                    kernel_initializer=initializer,
                                    use_bias=False))

  result.add(tf.keras.layers.BatchNormalization())

  if apply_dropout:
      result.add(tf.keras.layers.Dropout(0.5))

  result.add(tf.keras.layers.ReLU())

  return result
up_model = upsample(3, 4)
up_result = up_model(down_result)
print (up_result.shape)
(1, 256, 256, 3)

Definire il generatore con il downsampler e l'upsampler:

def Generator():
  inputs = tf.keras.layers.Input(shape=[256, 256, 3])

  down_stack = [
    downsample(64, 4, apply_batchnorm=False),  # (batch_size, 128, 128, 64)
    downsample(128, 4),  # (batch_size, 64, 64, 128)
    downsample(256, 4),  # (batch_size, 32, 32, 256)
    downsample(512, 4),  # (batch_size, 16, 16, 512)
    downsample(512, 4),  # (batch_size, 8, 8, 512)
    downsample(512, 4),  # (batch_size, 4, 4, 512)
    downsample(512, 4),  # (batch_size, 2, 2, 512)
    downsample(512, 4),  # (batch_size, 1, 1, 512)
  ]

  up_stack = [
    upsample(512, 4, apply_dropout=True),  # (batch_size, 2, 2, 1024)
    upsample(512, 4, apply_dropout=True),  # (batch_size, 4, 4, 1024)
    upsample(512, 4, apply_dropout=True),  # (batch_size, 8, 8, 1024)
    upsample(512, 4),  # (batch_size, 16, 16, 1024)
    upsample(256, 4),  # (batch_size, 32, 32, 512)
    upsample(128, 4),  # (batch_size, 64, 64, 256)
    upsample(64, 4),  # (batch_size, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                         strides=2,
                                         padding='same',
                                         kernel_initializer=initializer,
                                         activation='tanh')  # (batch_size, 256, 256, 3)

  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = tf.keras.layers.Concatenate()([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

Visualizza l'architettura del modello del generatore:

generator = Generator()
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

png

Testare il generatore:

gen_output = generator(inp[tf.newaxis, ...], training=False)
plt.imshow(gen_output[0, ...])
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<matplotlib.image.AxesImage at 0x7f35cfd20610>

png

Definire la perdita del generatore

I GAN apprendono una perdita che si adatta ai dati, mentre i cGAN apprendono una perdita strutturata che penalizza una possibile struttura che differisce dall'output di rete e dall'immagine target, come descritto nel documento pix2pix .

  • La perdita del generatore è una perdita di entropia incrociata sigmoidea delle immagini generate e di una serie di immagini.
  • Il documento pix2pix menziona anche la perdita L1, che è un MAE (errore medio assoluto) tra l'immagine generata e l'immagine target.
  • Ciò consente all'immagine generata di diventare strutturalmente simile all'immagine di destinazione.
  • La formula per calcolare la perdita totale del generatore è gan_loss + LAMBDA * l1_loss , dove LAMBDA = 100 . Questo valore è stato deciso dagli autori dell'articolo.
LAMBDA = 100
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def generator_loss(disc_generated_output, gen_output, target):
  gan_loss = loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

  # Mean absolute error
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

  total_gen_loss = gan_loss + (LAMBDA * l1_loss)

  return total_gen_loss, gan_loss, l1_loss

La procedura di addestramento per il generatore è la seguente:

Immagine di aggiornamento del generatore

Costruisci il discriminatore

Il discriminatore in pix2pix cGAN è un classificatore PatchGAN convoluzionale: tenta di classificare se ogni patch di immagine è reale o meno, come descritto nel documento pix2pix .

  • Ogni blocco nel discriminatore è: Convoluzione -> Normalizzazione batch -> Leaky ReLU.
  • La forma dell'output dopo l'ultimo livello è (batch_size, 30, 30, 1) .
  • Ogni patch di immagine 30 x 30 dell'output classifica una porzione 70 x 70 dell'immagine di input.
  • Il discriminatore riceve 2 ingressi:
    • L'immagine di input e l'immagine di destinazione, che dovrebbe classificare come reale.
    • L'immagine di input e l'immagine generata (l'output del generatore), che dovrebbe classificare come fake.
    • Utilizzare tf.concat([inp, tar], axis=-1) per concatenare questi 2 input insieme.

Definiamo il discriminatore:

def Discriminator():
  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[256, 256, 3], name='input_image')
  tar = tf.keras.layers.Input(shape=[256, 256, 3], name='target_image')

  x = tf.keras.layers.concatenate([inp, tar])  # (batch_size, 256, 256, channels*2)

  down1 = downsample(64, 4, False)(x)  # (batch_size, 128, 128, 64)
  down2 = downsample(128, 4)(down1)  # (batch_size, 64, 64, 128)
  down3 = downsample(256, 4)(down2)  # (batch_size, 32, 32, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (batch_size, 34, 34, 256)
  conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                kernel_initializer=initializer,
                                use_bias=False)(zero_pad1)  # (batch_size, 31, 31, 512)

  batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (batch_size, 33, 33, 512)

  last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                kernel_initializer=initializer)(zero_pad2)  # (batch_size, 30, 30, 1)

  return tf.keras.Model(inputs=[inp, tar], outputs=last)

Visualizza l'architettura del modello discriminatore:

discriminator = Discriminator()
tf.keras.utils.plot_model(discriminator, show_shapes=True, dpi=64)

png

Metti alla prova il discriminatore:

disc_out = discriminator([inp[tf.newaxis, ...], gen_output], training=False)
plt.imshow(disc_out[0, ..., -1], vmin=-20, vmax=20, cmap='RdBu_r')
plt.colorbar()
<matplotlib.colorbar.Colorbar at 0x7f35cec82c50>

png

Definire la perdita del discriminatore

  • La funzione discriminator_loss accetta 2 input: immagini reali e immagini generate .
  • real_loss è una perdita di entropia incrociata sigmoidea delle immagini reali e di una matrice di quelle (poiché queste sono le immagini reali) .
  • generated_loss è una perdita di entropia incrociata sigmoidea delle immagini generate e un array di zeri (poiché queste sono le immagini false) .
  • Il total_loss è la somma di real_loss e generated_loss .
def discriminator_loss(disc_real_output, disc_generated_output):
  real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)

  generated_loss = loss_object(tf.zeros_like(disc_generated_output), disc_generated_output)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss

Di seguito è illustrata la procedura di formazione per il discriminatore.

Per saperne di più sull'architettura e gli iperparametri si può fare riferimento al paper pix2pix .

Immagine di aggiornamento discriminatore

Definisci gli ottimizzatori e un checkpoint-saver

generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

Genera immagini

Scrivi una funzione per tracciare alcune immagini durante l'allenamento.

  • Passa le immagini dal set di prova al generatore.
  • Il generatore tradurrà quindi l'immagine di input nell'output.
  • L'ultimo passo è tracciare le previsioni e voilà !
def generate_images(model, test_input, tar):
  prediction = model(test_input, training=True)
  plt.figure(figsize=(15, 15))

  display_list = [test_input[0], tar[0], prediction[0]]
  title = ['Input Image', 'Ground Truth', 'Predicted Image']

  for i in range(3):
    plt.subplot(1, 3, i+1)
    plt.title(title[i])
    # Getting the pixel values in the [0, 1] range to plot.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

Testare la funzione:

for example_input, example_target in test_dataset.take(1):
  generate_images(generator, example_input, example_target)

png

Formazione

  • Per ogni input di esempio viene generato un output.
  • Il discriminatore riceve l' input_image e l'immagine generata come primo input. Il secondo input è input_image e target_image .
  • Quindi, calcola il generatore e la perdita del discriminatore.
  • Quindi, calcola i gradienti di perdita rispetto sia al generatore che alle variabili (ingressi) discriminatori e applica quelli all'ottimizzatore.
  • Infine, registra le perdite su TensorBoard.
log_dir="logs/"

summary_writer = tf.summary.create_file_writer(
  log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
@tf.function
def train_step(input_image, target, step):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=step//1000)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step//1000)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step//1000)
    tf.summary.scalar('disc_loss', disc_loss, step=step//1000)

Il ciclo di allenamento vero e proprio. Poiché questo tutorial può essere eseguito su più di un set di dati e le dimensioni dei set di dati variano notevolmente, il ciclo di addestramento è impostato per funzionare in fasi anziché in epoche.

  • Itera sul numero di passaggi.
  • Ogni 10 passi stampa un punto ( . ).
  • Ogni 1k passaggi: cancella il display ed esegui generate_images per mostrare l'avanzamento.
  • Ogni 5k passi: salva un checkpoint.
def fit(train_ds, test_ds, steps):
  example_input, example_target = next(iter(test_ds.take(1)))
  start = time.time()

  for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():
    if (step) % 1000 == 0:
      display.clear_output(wait=True)

      if step != 0:
        print(f'Time taken for 1000 steps: {time.time()-start:.2f} sec\n')

      start = time.time()

      generate_images(generator, example_input, example_target)
      print(f"Step: {step//1000}k")

    train_step(input_image, target, step)

    # Training step
    if (step+1) % 10 == 0:
      print('.', end='', flush=True)


    # Save (checkpoint) the model every 5k steps
    if (step + 1) % 5000 == 0:
      checkpoint.save(file_prefix=checkpoint_prefix)

Questo ciclo di formazione salva i registri che puoi visualizzare in TensorBoard per monitorare l'avanzamento della formazione.

Se lavori su una macchina locale, avvierai un processo TensorBoard separato. Quando si lavora su un notebook, avviare il visualizzatore prima di iniziare la formazione per monitorare con TensorBoard.

Per avviare il visualizzatore, incolla quanto segue in una cella di codice:

%load_ext tensorboard
%tensorboard --logdir {log_dir}

Infine, esegui il ciclo di formazione:

fit(train_dataset, test_dataset, steps=40000)
Time taken for 1000 steps: 36.53 sec

png

Step: 39k
....................................................................................................

Se desideri condividere pubblicamente i risultati di TensorBoard , puoi caricare i log su TensorBoard.dev copiando quanto segue in una cella di codice.

tensorboard dev upload --logdir {log_dir}

È possibile visualizzare i risultati di un'esecuzione precedente di questo notebook su TensorBoard.dev .

TensorBoard.dev è un'esperienza gestita per l'hosting, il monitoraggio e la condivisione di esperimenti di machine learning con tutti.

Può anche essere incluso in linea usando un <iframe> :

display.IFrame(
    src="https://tensorboard.dev/experiment/lZ0C6FONROaUMfjYkVyJqw",
    width="100%",
    height="1000px")

L'interpretazione dei log è più sottile quando si addestra un GAN (o un cGAN come pix2pix) rispetto a un semplice modello di classificazione o regressione. Cose da cercare:

  • Verificare che né il modello generatore né quello discriminatore abbiano "vinto". Se gen_gan_loss o disc_loss diventano molto bassi, è un indicatore che questo modello sta dominando l'altro e che non stai addestrando con successo il modello combinato.
  • Il valore log(2) = 0.69 è un buon punto di riferimento per queste perdite, in quanto indica una perplessità di 2 - il discriminatore è, in media, ugualmente incerto sulle due opzioni.
  • Per disc_loss , un valore inferiore a 0.69 significa che il discriminatore sta andando meglio di random sull'insieme combinato di immagini reali e generate.
  • Per gen_gan_loss , un valore inferiore a 0.69 significa che il generatore sta facendo meglio di random nell'ingannare il discriminatore.
  • Con il progredire della formazione, gen_l1_loss dovrebbe diminuire.

Ripristina l'ultimo checkpoint e testa la rete

ls {checkpoint_dir}
checkpoint          ckpt-5.data-00000-of-00001
ckpt-1.data-00000-of-00001  ckpt-5.index
ckpt-1.index            ckpt-6.data-00000-of-00001
ckpt-2.data-00000-of-00001  ckpt-6.index
ckpt-2.index            ckpt-7.data-00000-of-00001
ckpt-3.data-00000-of-00001  ckpt-7.index
ckpt-3.index            ckpt-8.data-00000-of-00001
ckpt-4.data-00000-of-00001  ckpt-8.index
ckpt-4.index
# Restoring the latest checkpoint in checkpoint_dir
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f35cfd6b8d0>

Genera alcune immagini usando il set di prova

# Run the trained model on a few examples from the test set
for inp, tar in test_dataset.take(5):
  generate_images(generator, inp, tar)

png

png

png

png

png