CycleGAN

مشاهده در TensorFlow.org در Google Colab اجرا شود مشاهده منبع در GitHub دانلود دفترچه یادداشت

این نوت بوک ترجمه جفت نشده تصویر به تصویر را با استفاده از GAN های شرطی نشان می دهد، همانطور که در ترجمه تصویر به تصویر بدون جفت با استفاده از شبکه های متخاصم با چرخه ، که به نام CycleGAN نیز شناخته می شود، توضیح داده شده است. این مقاله روشی را پیشنهاد می‌کند که می‌تواند ویژگی‌های یک حوزه تصویر را ضبط کند و بفهمد که چگونه این ویژگی‌ها می‌توانند به حوزه تصویر دیگری ترجمه شوند، همه در غیاب هر گونه مثال آموزشی جفتی.

این نوت بوک فرض می کند که شما با Pix2Pix آشنا هستید که می توانید در آموزش Pix2Pix با آن آشنا شوید. کد CycleGAN مشابه است، تفاوت اصلی تابع ضرر اضافی و استفاده از داده های آموزشی جفت نشده است.

CycleGAN از کاهش ثبات چرخه برای فعال کردن آموزش بدون نیاز به داده های جفت استفاده می کند. به عبارت دیگر، می تواند از یک دامنه به دامنه دیگر بدون نگاشت یک به یک بین منبع و دامنه مقصد ترجمه کند.

این امکان را برای انجام بسیاری از کارهای جالب مانند بهبود عکس، رنگ آمیزی تصویر، انتقال سبک و غیره باز می کند. تنها چیزی که نیاز دارید منبع و مجموعه داده هدف است (که به سادگی فهرستی از تصاویر است).

تصویر خروجی 1تصویر خروجی 2

خط لوله ورودی را تنظیم کنید

بسته tensorflow_examples را نصب کنید که امکان وارد کردن ژنراتور و تفکیک کننده را فراهم می کند.

pip install git+https://github.com/tensorflow/examples.git
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

AUTOTUNE = tf.data.AUTOTUNE

خط لوله ورودی

این آموزش مدلی را آموزش می دهد تا از تصاویر اسب به تصاویر گورخر ترجمه کند. شما می توانید این مجموعه داده و موارد مشابه را در اینجا بیابید.

همانطور که در مقاله ذکر شد، لرزش تصادفی و آینه سازی را به مجموعه داده آموزشی اعمال کنید. اینها برخی از تکنیک های تقویت تصویر هستند که از برازش بیش از حد جلوگیری می کنند.

این شبیه کاری است که در pix2pix انجام شد

  • در جیترینگ تصادفی، اندازه تصویر به 286 x 286 تغییر می‌کند و سپس به‌طور تصادفی به 256 x 256 برش داده می‌شود.
  • در انعکاس تصادفی، تصویر به طور تصادفی به صورت افقی برگردانده می شود، یعنی از چپ به راست.
dataset, metadata = tfds.load('cycle_gan/horse2zebra',
                              with_info=True, as_supervised=True)

train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
def random_crop(image):
  cropped_image = tf.image.random_crop(
      image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

  return cropped_image
# normalizing the images to [-1, 1]
def normalize(image):
  image = tf.cast(image, tf.float32)
  image = (image / 127.5) - 1
  return image
def random_jitter(image):
  # resizing to 286 x 286 x 3
  image = tf.image.resize(image, [286, 286],
                          method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

  # randomly cropping to 256 x 256 x 3
  image = random_crop(image)

  # random mirroring
  image = tf.image.random_flip_left_right(image)

  return image
def preprocess_image_train(image, label):
  image = random_jitter(image)
  image = normalize(image)
  return image
def preprocess_image_test(image, label):
  image = normalize(image)
  return image
train_horses = train_horses.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

train_zebras = train_zebras.cache().map(
    preprocess_image_train, num_parallel_calls=AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_horses = test_horses.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

test_zebras = test_zebras.map(
    preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))
2022-01-26 02:38:15.762422: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-01-26 02:38:19.927846: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
-l10n-
plt.subplot(121)
plt.title('Horse')
plt.imshow(sample_horse[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Horse with random jitter')
plt.imshow(random_jitter(sample_horse[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7f7cf83e0050>

png

plt.subplot(121)
plt.title('Zebra')
plt.imshow(sample_zebra[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('Zebra with random jitter')
plt.imshow(random_jitter(sample_zebra[0]) * 0.5 + 0.5)
<matplotlib.image.AxesImage at 0x7f7cf8139490>

png

مدل‌های Pix2Pix را وارد کرده و دوباره استفاده کنید

مولد و تمایز مورد استفاده در Pix2Pix را از طریق بسته نصب شده tensorflow_examples وارد کنید.

معماری مدل استفاده شده در این آموزش بسیار شبیه به آنچه در pix2pix استفاده شده است. برخی از تفاوت ها عبارتند از:

2 ژنراتور (G و F) و 2 تشخیص دهنده (X و Y) در اینجا آموزش می بینند.

  • ژنراتور G یاد می گیرد که تصویر X را به تصویر Y تبدیل کند. \((G: X -> Y)\)
  • ژنراتور F می آموزد که تصویر Y را به تصویر X تبدیل کند. \((F: Y -> X)\)
  • D_X یاد می گیرد که بین تصویر X و تصویر X ( F(Y) ) تفاوت قائل شود.
  • D_Y یاد می گیرد که بین تصویر Y و تصویر Y تولید شده ( G(X) تفاوت قائل شود.

مدل سیکلگان

OUTPUT_CHANNELS = 3

generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
generator_f = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

discriminator_x = pix2pix.discriminator(norm_type='instancenorm', target=False)
discriminator_y = pix2pix.discriminator(norm_type='instancenorm', target=False)
to_zebra = generator_g(sample_horse)
to_horse = generator_f(sample_zebra)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

for i in range(len(imgs)):
  plt.subplot(2, 2, i+1)
  plt.title(title[i])
  if i % 2 == 0:
    plt.imshow(imgs[i][0] * 0.5 + 0.5)
  else:
    plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
plt.show()
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

png

plt.figure(figsize=(8, 8))

plt.subplot(121)
plt.title('Is a real zebra?')
plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

plt.subplot(122)
plt.title('Is a real horse?')
plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

plt.show()

png

توابع از دست دادن

در CycleGAN، هیچ داده جفتی برای آموزش وجود ندارد، از این رو هیچ تضمینی وجود ندارد که ورودی x و جفت هدف y در طول آموزش معنادار باشند. بنابراین برای اینکه شبکه نقشه برداری صحیح را بیاموزد، نویسندگان از دست دادن ثبات چرخه را پیشنهاد می کنند.

از دست دادن تفکیک کننده و از دست دادن ژنراتور مشابه موارد استفاده شده در pix2pix است.

LAMBDA = 10
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real, generated):
  real_loss = loss_obj(tf.ones_like(real), real)

  generated_loss = loss_obj(tf.zeros_like(generated), generated)

  total_disc_loss = real_loss + generated_loss

  return total_disc_loss * 0.5
def generator_loss(generated):
  return loss_obj(tf.ones_like(generated), generated)

سازگاری چرخه به این معنی است که نتیجه باید نزدیک به ورودی اصلی باشد. به عنوان مثال، اگر یک جمله را از انگلیسی به فرانسوی ترجمه کنید، و سپس آن را از فرانسوی به انگلیسی ترجمه کنید، آنگاه جمله حاصل باید مانند جمله اصلی باشد.

در از دست دادن ثبات چرخه،

  • تصویر \(X\) از طریق ژنراتور \(G\) ارسال می شود که تصویر تولید شده \(\hat{Y}\)را به دست می دهد.
  • تصویر تولید شده \(\hat{Y}\) از طریق ژنراتور \(F\) می شود که تصویر چرخه ای \(\hat{X}\)ایجاد می کند.
  • میانگین خطای مطلق بین \(X\) و \(\hat{X}\)محاسبه می شود.

\[forward\ cycle\ consistency\ loss: X -> G(X) -> F(G(X)) \sim \hat{X}\]

\[backward\ cycle\ consistency\ loss: Y -> F(Y) -> G(F(Y)) \sim \hat{Y}\]

از دست دادن چرخه

def calc_cycle_loss(real_image, cycled_image):
  loss1 = tf.reduce_mean(tf.abs(real_image - cycled_image))

  return LAMBDA * loss1

همانطور که در بالا نشان داده شد، مولد \(G\) مسئول ترجمه تصویر \(X\) به تصویر \(Y\)است. از دست دادن هویت می‌گوید که، اگر تصویر \(Y\) را به ژنراتور \(G\)، باید تصویر واقعی \(Y\) یا چیزی نزدیک به تصویر \(Y\)ایجاد کند.

اگر مدل گورخر به اسب را روی اسب یا مدل اسب به گورخر را روی گورخر اجرا کنید، نباید تصویر را زیاد تغییر دهید زیرا تصویر قبلاً شامل کلاس هدف است.

\[Identity\ loss = |G(Y) - Y| + |F(X) - X|\]

def identity_loss(real_image, same_image):
  loss = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * loss

بهینه سازها را برای همه مولدها و متمایزکننده ها راه اندازی کنید.

generator_g_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
generator_f_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

discriminator_x_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
discriminator_y_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

ایست های بازرسی

checkpoint_path = "./checkpoints/train"

ckpt = tf.train.Checkpoint(generator_g=generator_g,
                           generator_f=generator_f,
                           discriminator_x=discriminator_x,
                           discriminator_y=discriminator_y,
                           generator_g_optimizer=generator_g_optimizer,
                           generator_f_optimizer=generator_f_optimizer,
                           discriminator_x_optimizer=discriminator_x_optimizer,
                           discriminator_y_optimizer=discriminator_y_optimizer)

ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# if a checkpoint exists, restore the latest checkpoint.
if ckpt_manager.latest_checkpoint:
  ckpt.restore(ckpt_manager.latest_checkpoint)
  print ('Latest checkpoint restored!!')

آموزش

EPOCHS = 40
def generate_images(model, test_input):
  prediction = model(test_input)

  plt.figure(figsize=(12, 12))

  display_list = [test_input[0], prediction[0]]
  title = ['Input Image', 'Predicted Image']

  for i in range(2):
    plt.subplot(1, 2, i+1)
    plt.title(title[i])
    # getting the pixel values between [0, 1] to plot it.
    plt.imshow(display_list[i] * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

اگرچه حلقه آموزشی پیچیده به نظر می رسد، از چهار مرحله اساسی تشکیل شده است:

  • پیش بینی ها را دریافت کنید
  • ضرر را محاسبه کنید.
  • شیب ها را با استفاده از پس انتشار محاسبه کنید.
  • گرادیان ها را روی بهینه ساز اعمال کنید.
@tf.function
def train_step(real_x, real_y):
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # Generator G translates X -> Y
    # Generator F translates Y -> X.

    fake_y = generator_g(real_x, training=True)
    cycled_x = generator_f(fake_y, training=True)

    fake_x = generator_f(real_y, training=True)
    cycled_y = generator_g(fake_x, training=True)

    # same_x and same_y are used for identity loss.
    same_x = generator_f(real_x, training=True)
    same_y = generator_g(real_y, training=True)

    disc_real_x = discriminator_x(real_x, training=True)
    disc_real_y = discriminator_y(real_y, training=True)

    disc_fake_x = discriminator_x(fake_x, training=True)
    disc_fake_y = discriminator_y(fake_y, training=True)

    # calculate the loss
    gen_g_loss = generator_loss(disc_fake_y)
    gen_f_loss = generator_loss(disc_fake_x)

    total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

    # Total generator loss = adversarial loss + cycle loss
    total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
    total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

    disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
    disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

  # Calculate the gradients for generator and discriminator
  generator_g_gradients = tape.gradient(total_gen_g_loss, 
                                        generator_g.trainable_variables)
  generator_f_gradients = tape.gradient(total_gen_f_loss, 
                                        generator_f.trainable_variables)

  discriminator_x_gradients = tape.gradient(disc_x_loss, 
                                            discriminator_x.trainable_variables)
  discriminator_y_gradients = tape.gradient(disc_y_loss, 
                                            discriminator_y.trainable_variables)

  # Apply the gradients to the optimizer
  generator_g_optimizer.apply_gradients(zip(generator_g_gradients, 
                                            generator_g.trainable_variables))

  generator_f_optimizer.apply_gradients(zip(generator_f_gradients, 
                                            generator_f.trainable_variables))

  discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients,
                                                discriminator_x.trainable_variables))

  discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients,
                                                discriminator_y.trainable_variables))
for epoch in range(EPOCHS):
  start = time.time()

  n = 0
  for image_x, image_y in tf.data.Dataset.zip((train_horses, train_zebras)):
    train_step(image_x, image_y)
    if n % 10 == 0:
      print ('.', end='')
    n += 1

  clear_output(wait=True)
  # Using a consistent image (sample_horse) so that the progress of the model
  # is clearly visible.
  generate_images(generator_g, sample_horse)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,
                                                         ckpt_save_path))

  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                      time.time()-start))

png

Saving checkpoint for epoch 40 at ./checkpoints/train/ckpt-8
Time taken for epoch 40 is 166.64579939842224 sec

با استفاده از مجموعه داده آزمایشی ایجاد کنید

# Run the trained model on the test dataset
for inp in test_horses.take(5):
  generate_images(generator_g, inp)

png

png

png

png

png

مراحل بعدی

این آموزش نحوه پیاده سازی CycleGAN را با شروع از مولد و تشخیص دهنده پیاده سازی شده در آموزش Pix2Pix نشان می دهد. به عنوان گام بعدی، می توانید از مجموعه داده متفاوتی از TensorFlow Datasets استفاده کنید.

همچنین می‌توانید برای بهبود نتایج، دوره‌های بیشتری را آموزش دهید، یا می‌توانید به جای ژنراتور U-Net که در اینجا استفاده می‌شود، مولد ResNet اصلاح‌شده مورد استفاده در مقاله را اجرا کنید.