مشاهده در TensorFlow.org | اجرا در Google Colab | مشاهده در GitHub | دانلود نوت‌بوک | مشاهدهٔ مدل در TF Hub
این Colab نحوهٔ استفاده از ماژول TensorFlow Hub برای شبکهٔ مولد تخاصمی وضوح فوق‌العادهٔ بهبودیافته (ESRGAN، اثر Xintao Wang و همکاران) [ مقاله ] [ کد ]
را برای بهبود کیفیت تصویر (ترجیحاً تصاویری که با درون‌یابی دومکعبی کوچک شده‌اند) نشان می‌دهد.
این مدل روی مجموعه‌دادهٔ DIV2K (روی تصاویری که با روش دومکعبی کوچک‌نمایی شده‌اند) و با وصله‌های تصویری به اندازهٔ 128 x 128 آموزش داده شده است.
آماده‌سازی محیط
import os
import time
from PIL import Image
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
os.environ["TFHUB_DOWNLOAD_PROGRESS"] = "True"
wget "https://user-images.githubusercontent.com/12981474/40157448-eff91f06-5953-11e8-9a37-f6b5693fa03f.png" -O original.png
--2021-11-05 12:46:51-- https://user-images.githubusercontent.com/12981474/40157448-eff91f06-5953-11e8-9a37-f6b5693fa03f.png Resolving user-images.githubusercontent.com (user-images.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.111.133, ... Connecting to user-images.githubusercontent.com (user-images.githubusercontent.com)|185.199.109.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 34146 (33K) [image/png] Saving to: ‘original.png’ original.png 100%[===================>] 33.35K --.-KB/s in 0.002s 2021-11-05 12:46:51 (13.2 MB/s) - ‘original.png’ saved [34146/34146]
# Declaring Constants
# Local file written by the wget download cell above.
IMAGE_PATH = "original.png"
# TF Hub handle of the ESRGAN model, passed to hub.load().
SAVED_MODEL_PATH = "https://tfhub.dev/captain-pool/esrgan-tf2/1"
تعریف توابع کمکی
def preprocess_image(image_path):
  """Loads an image from disk and prepares it as a model-ready batch.

  Args:
    image_path: Path to the image file.

  Returns:
    float32 tensor of shape [1, height, width, 3] with pixel values in the
    original 0-255 range. Height and width are cropped down to the nearest
    multiple of 4.
  """
  hr_image = tf.image.decode_image(tf.io.read_file(image_path))
  # The model only supports images with 3 color channels.
  if hr_image.shape[-1] in (2, 4):
    # Gray+alpha or RGBA (e.g. PNG): drop the trailing alpha channel.
    hr_image = hr_image[..., :-1]
  if hr_image.shape[-1] == 1:
    # Grayscale: replicate the single channel into R, G and B.
    hr_image = tf.image.grayscale_to_rgb(hr_image)
  # Crop to multiples of 4 so that a x4 bicubic downscale followed by the
  # model's upscale maps back to exactly this size (needed for PSNR).
  hr_size = (tf.convert_to_tensor(hr_image.shape[:-1]) // 4) * 4
  hr_image = tf.image.crop_to_bounding_box(hr_image, 0, 0, hr_size[0], hr_size[1])
  hr_image = tf.cast(hr_image, tf.float32)
  return tf.expand_dims(hr_image, 0)
def save_image(image, filename):
  """Writes an image to disk as a JPEG.

  Args:
    image: Either a PIL image, saved as-is, or a 3D [height, width, channels]
      tensor/array whose values are clipped to [0, 255] and cast to uint8.
    filename: Output name without extension; ".jpg" is appended.
  """
  pil_image = image
  if not isinstance(pil_image, Image.Image):
    clipped = tf.clip_by_value(pil_image, 0, 255)
    pixels = tf.cast(clipped, tf.uint8).numpy()
    pil_image = Image.fromarray(pixels)
  pil_image.save("%s.jpg" % filename)
  print("Saved as %s.jpg" % filename)
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
def plot_image(image, title=""):
  """Draws an image on the current matplotlib axes, without axis ticks.

  Args:
    image: 3D image tensor [height, width, channels]; values are clipped to
      [0, 255] and cast to uint8 before display.
    title: Title shown above the plot.
  """
  pixels = tf.cast(tf.clip_by_value(np.asarray(image), 0, 255), tf.uint8)
  plt.imshow(Image.fromarray(pixels.numpy()))
  plt.axis("off")
  plt.title(title)
انجام وضوح فوق‌العاده روی تصاویر بارگذاری‌شده از مسیر
hr_image = preprocess_image(IMAGE_PATH)
# Plotting Original Resolution image: show and save it without the batch axis.
original_view = tf.squeeze(hr_image)
plot_image(original_view, title="Original Image")
save_image(original_view, filename="Original Image")
Saved as Original Image.jpg
# Load the ESRGAN model from the TF Hub handle stored in SAVED_MODEL_PATH.
model = hub.load(SAVED_MODEL_PATH)
Downloaded https://tfhub.dev/captain-pool/esrgan-tf2/1, Total size: 20.60MB
# perf_counter is a monotonic, high-resolution clock meant for measuring
# intervals; time.time() follows the wall clock and can jump.
start = time.perf_counter()
# hr_image is a [1, H, W, 3] batch; the result is squeezed to drop the batch
# axis (presumably the model returns a batch of one — confirm).
fake_image = model(hr_image)
fake_image = tf.squeeze(fake_image)
print("Time Taken: %f" % (time.perf_counter() - start))
Time Taken: 2.695235
# Plotting Super Resolution Image: show and save the same squeezed tensor.
sr_view = tf.squeeze(fake_image)
plot_image(sr_view, title="Super Resolution")
save_image(sr_view, filename="Super Resolution")
Saved as Super Resolution.jpg
ارزیابی عملکرد مدل
# Download a second image for evaluating the model (notebook shell escape).
!wget "https://lh4.googleusercontent.com/-Anmw5df4gj0/AAAAAAAAAAI/AAAAAAAAAAc/6HxU8XFLnQE/photo.jpg64" -O test.jpg
# Subsequent cells read the evaluation image through IMAGE_PATH.
IMAGE_PATH = "test.jpg"
--2021-11-05 12:47:03-- https://lh4.googleusercontent.com/-Anmw5df4gj0/AAAAAAAAAAI/AAAAAAAAAAc/6HxU8XFLnQE/photo.jpg64 Resolving lh4.googleusercontent.com (lh4.googleusercontent.com)... 64.233.188.132, 2404:6800:4008:c06::84 Connecting to lh4.googleusercontent.com (lh4.googleusercontent.com)|64.233.188.132|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 84897 (83K) [image/jpeg] Saving to: ‘test.jpg’ test.jpg 100%[===================>] 82.91K --.-KB/s in 0.001s 2021-11-05 12:47:04 (94.8 MB/s) - ‘test.jpg’ saved [84897/84897]
# Defining helper functions
def downscale_image(image, scale=4):
  """Scales an image down using bicubic resampling.

  Args:
    image: Preprocessed image with pixel values in [0, 255], either a 3D
      tensor [height, width, channels] or a 4D batch holding exactly one
      image [1, height, width, channels].
    scale: Integer downscaling factor applied to both height and width.
      Defaults to 4, the factor used throughout this notebook.

  Returns:
    float32 tensor with a leading batch axis of size 1 and spatial size
    (height // scale, width // scale).

  Raises:
    ValueError: If `scale` is smaller than 1, or if `image` is neither a
      single 3D image nor a batch of one.
  """
  if scale < 1:
    raise ValueError("scale must be a positive integer, got %r" % (scale,))
  if len(image.shape) == 4 and image.shape[0] == 1:
    # Accept a batch of exactly one image by dropping the batch axis.
    image = image[0]
  if len(image.shape) != 3:
    raise ValueError("Dimension mismatch. Can work only on single image.")
  height, width = image.shape[0], image.shape[1]
  # PIL needs uint8 pixels, so clip to the valid range before casting.
  image = tf.squeeze(tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8))
  # PIL's resize takes (width, height), the reverse of the tensor axis order.
  lr_image = np.asarray(
      Image.fromarray(image.numpy()).resize(
          [width // scale, height // scale], Image.BICUBIC))
  lr_image = tf.expand_dims(lr_image, 0)
  return tf.cast(lr_image, tf.float32)
hr_image = preprocess_image(IMAGE_PATH)
# Build a low-resolution input by bicubic x4 downscaling of the reference.
lr_image = downscale_image(tf.squeeze(hr_image))
# Plotting Low Resolution Image
plot_image(tf.squeeze(lr_image), title="Low Resolution")
# Loaded again so this evaluation section can run on its own.
model = hub.load(SAVED_MODEL_PATH)
# Monotonic interval clock, as opposed to the adjustable wall clock.
start = time.perf_counter()
fake_image = model(lr_image)
fake_image = tf.squeeze(fake_image)
print("Time Taken: %f" % (time.perf_counter() - start))
Time Taken: 1.161794
plot_image(tf.squeeze(fake_image), title="Super Resolution")
# Calculating PSNR wrt Original Image.
# fake_image was squeezed to [H, W, 3] but hr_image keeps its batch axis
# [1, H, W, 3]; comparing them broadcasts and yields a shape-[1] tensor.
# Dropping the batch axis makes psnr a scalar, so "%f" formatting does not
# rely on NumPy's deprecated conversion of ndim>0 arrays to scalars.
psnr = tf.image.psnr(
    tf.clip_by_value(fake_image, 0, 255),
    tf.clip_by_value(hr_image[0], 0, 255), max_val=255)
print("PSNR Achieved: %f" % psnr)
PSNR Achieved: 28.029171
مقایسهٔ اندازهٔ خروجی‌ها در کنار یکدیگر
plt.rcParams['figure.figsize'] = [15, 10]
fig, axes = plt.subplots(1, 3)
# Draw each panel into the axes created above; plot_image draws on the
# current axes, so select each one first instead of re-creating subplots.
panels = [
    (hr_image, "Original"),
    (lr_image, "x4 Bicubic"),
    (fake_image, "Super Resolution"),
]
for ax, (panel_image, panel_title) in zip(axes, panels):
  plt.sca(ax)
  plot_image(tf.squeeze(panel_image), title=panel_title)
# Lay out once, after the images and titles exist, so spacing accounts for them.
fig.tight_layout()
plt.savefig("ESRGAN_DIV2K.jpg", bbox_inches="tight")
print("PSNR: %f" % psnr)
PSNR: 28.029171