Jaringan Permusuhan Generatif Konvolusi Dalam

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Tutorial ini mendemonstrasikan cara menghasilkan gambar angka tulisan tangan menggunakan Deep Convolutional Generative Adversarial Network (DCGAN). Kode ditulis menggunakan Keras Sequential API dengan tf.GradientTape training loop.

Apa itu GAN?

Generative Adversarial Networks (GANs) adalah salah satu ide paling menarik dalam ilmu komputer saat ini. Dua model dilatih secara bersamaan oleh proses permusuhan. Seorang generator ("seniman") belajar membuat gambar yang terlihat nyata, sementara seorang diskriminator ("kritikus seni") belajar membedakan gambar asli dari yang palsu.

Diagram generator dan diskriminator

Selama pelatihan, generator secara bertahap menjadi lebih baik dalam menciptakan gambar yang terlihat nyata, sementara diskriminator menjadi lebih baik dalam membedakannya. Proses mencapai keseimbangan ketika diskriminator tidak bisa lagi membedakan gambar asli dan palsu.

Diagram kedua dari generator dan diskriminator

Notebook ini menunjukkan proses ini pada set data MNIST. Animasi berikut menunjukkan serangkaian gambar yang dihasilkan oleh generator seperti yang dilatih selama 50 zaman. Gambar dimulai sebagai noise acak, dan semakin menyerupai angka tulisan tangan dari waktu ke waktu.

keluaran sampel

Untuk mempelajari lebih lanjut tentang GAN, lihat kursus Intro to Deep Learning MIT.

Mempersiapkan

import tensorflow as tf
tf.__version__
'2.8.0-rc1'
# To generate GIFs
pip install imageio
pip install git+https://github.com/tensorflow/docs
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

Muat dan siapkan kumpulan data

Anda akan menggunakan dataset MNIST untuk melatih generator dan diskriminator. Generator akan menghasilkan angka tulisan tangan yang menyerupai data MNIST.

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # Normalize the images to [-1, 1]
BUFFER_SIZE = 60000
BATCH_SIZE = 256
# Batch and shuffle the data
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Buat modelnya

Baik generator maupun diskriminator ditentukan menggunakan Keras Sequential API .

Pembangkit

Generator menggunakan tf.keras.layers.Conv2DTranspose (upsampling) untuk menghasilkan gambar dari benih (noise acak). Mulailah dengan lapisan Dense yang mengambil benih ini sebagai input, lalu upsample beberapa kali hingga Anda mencapai ukuran gambar yang diinginkan 28x28x1. Perhatikan aktivasi tf.keras.layers.LeakyReLU untuk setiap layer, kecuali layer output yang menggunakan tanh.

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256)  # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

Gunakan generator (belum terlatih) untuk membuat gambar.

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')
<matplotlib.image.AxesImage at 0x7f6fe7a04b90>

png

Diskriminator

Diskriminator adalah pengklasifikasi gambar berbasis CNN.

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

Gunakan diskriminator (belum terlatih) untuk mengklasifikasikan gambar yang dihasilkan sebagai nyata atau palsu. Model akan dilatih untuk menghasilkan nilai positif untuk gambar asli, dan nilai negatif untuk gambar palsu.

discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)
tf.Tensor([[-0.00339105]], shape=(1, 1), dtype=float32)

Tentukan kerugian dan pengoptimal

Tentukan fungsi kerugian dan pengoptimal untuk kedua model.

# This method returns a helper function to compute cross entropy loss
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

Kerugian diskriminator

Metode ini mengukur seberapa baik diskriminator mampu membedakan gambar asli dari palsu. Ini membandingkan prediksi diskriminator pada gambar nyata ke array 1s, dan prediksi diskriminator pada gambar palsu (dihasilkan) ke array 0s.

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

Kehilangan generator

Kerugian generator mengkuantifikasi seberapa baik ia mampu mengelabui diskriminator. Secara intuitif, jika generator berkinerja baik, diskriminator akan mengklasifikasikan gambar palsu sebagai nyata (atau 1). Di sini, bandingkan keputusan diskriminator pada gambar yang dihasilkan dengan array 1s.

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

Diskriminator dan pengoptimal generator berbeda karena Anda akan melatih dua jaringan secara terpisah.

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

Simpan pos pemeriksaan

Notebook ini juga menunjukkan cara menyimpan dan memulihkan model, yang dapat membantu jika tugas pelatihan yang berjalan lama terganggu.

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

Tentukan lingkaran pelatihan

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

# You will reuse this seed overtime (so it's easier)
# to visualize progress in the animated GIF)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

Loop pelatihan dimulai dengan generator menerima benih acak sebagai input. Benih itu digunakan untuk menghasilkan gambar. Diskriminator kemudian digunakan untuk mengklasifikasikan gambar asli (diambil dari set pelatihan) dan gambar palsu (diproduksi oleh generator). Kerugian dihitung untuk masing-masing model ini, dan gradien digunakan untuk memperbarui generator dan diskriminator.

# Notice the use of `tf.function`
# This annotation causes the function to be "compiled".
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as you go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

Hasilkan dan simpan gambar

def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False.
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

Latih modelnya

Panggil metode train() yang didefinisikan di atas untuk melatih generator dan diskriminator secara bersamaan. Catatan, melatih GAN bisa jadi rumit. Penting agar generator dan diskriminator tidak saling mengalahkan (misalnya, mereka berlatih dengan kecepatan yang sama).

Pada awal pelatihan, gambar yang dihasilkan terlihat seperti noise acak. Saat pelatihan berlangsung, angka yang dihasilkan akan terlihat semakin nyata. Setelah sekitar 50 zaman, mereka menyerupai angka MNIST. Ini mungkin memakan waktu sekitar satu menit / waktu dengan pengaturan default di Colab.

train(train_dataset, EPOCHS)

png

Kembalikan pos pemeriksaan terbaru.

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f6ee8136950>

Buat GIF

# Display a single image using the epoch number
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
display_image(EPOCHS)

png

Gunakan imageio untuk membuat gif animasi menggunakan gambar yang disimpan selama pelatihan.

anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  for filename in filenames:
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
import tensorflow_docs.vis.embed as embed
embed.embed_file(anim_file)

gif

Langkah selanjutnya

Tutorial ini telah menunjukkan kode lengkap yang diperlukan untuk menulis dan melatih GAN. Sebagai langkah selanjutnya, Anda mungkin ingin bereksperimen dengan kumpulan data yang berbeda, misalnya kumpulan data Atribut Wajah Selebriti Skala Besar (CelebA) yang tersedia di Kaggle . Untuk mempelajari lebih lanjut tentang GAN, lihat Tutorial NIPS 2016: Generative Adversarial Networks .