TensorFlow.org で表示 | Google Colabで実行 | GitHub でソースを表示 | ノートブックをダウンロード | TF Hub モデルを参照 |
Boundless モデル Colab へようこそ!このノートブックでは、画像のモデルを実行し、結果を可視化するまでの手順を説明します。
概要
Boundless は画像外挿用のモデルです。このモデルは画像を受け取り、その画像の一部分 (1/2、1/4、3/4) を内部的にマスクし、マスクされた部分を補完します。詳細は Boundless: 画像拡張のための敵対的生成ネットワークまたは TensorFlow Hub のモデルに関するドキュメントをご覧ください。
インポートとセットアップ
基本のインポートから始めます。
import tensorflow as tf
import tensorflow_hub as hub
from io import BytesIO
from PIL import Image as PilImage
import numpy as np
from matplotlib import pyplot as plt
from six.moves.urllib.request import urlopen
2024-01-11 18:37:34.307869: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-01-11 18:37:34.307929: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-01-11 18:37:34.309531: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
入力する画像を読み取る
util メソッドを作成して画像を読み込み、モデル用にフォーマットしてみましょう (257x257x3)。このメソッドは歪みを避けるために画像を正方形にトリミングし、ローカルの画像やインターネットからの画像を使用することも可能です。
def read_image(filename):
fd = None
if(filename.startswith('http')):
fd = urlopen(filename)
else:
fd = tf.io.gfile.GFile(filename, 'rb')
pil_image = PilImage.open(fd)
width, height = pil_image.size
# crop to make the image square
pil_image = pil_image.crop((0, 0, height, height))
pil_image = pil_image.resize((257,257),PilImage.LANCZOS)
image_unscaled = np.array(pil_image)
image_np = np.expand_dims(
image_unscaled.astype(np.float32) / 255., axis=0)
return image_np
可視化メソッド
また可視化メソッドを作成して、モデルが生成したマスクされたバージョンと「塗りつぶされた」バージョンの両方で、元の画像が並べて表示されるようにします。
def visualize_output_comparison(img_original, img_masked, img_filled):
plt.figure(figsize=(24,12))
plt.subplot(131)
plt.imshow((np.squeeze(img_original)))
plt.title("Original", fontsize=24)
plt.axis('off')
plt.subplot(132)
plt.imshow((np.squeeze(img_masked)))
plt.title("Masked", fontsize=24)
plt.axis('off')
plt.subplot(133)
plt.imshow((np.squeeze(img_filled)))
plt.title("Generated", fontsize=24)
plt.axis('off')
plt.show()
画像を読み込む
ここではサンプル画像を読み込みますが、独自の画像を Colab にアップロードしてご自由にお試しください。なお、モデルには人物画像に関する制限があるので注意してください。
wikimedia = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/31/Nusfjord_road%2C_2010_09.jpg/800px-Nusfjord_road%2C_2010_09.jpg"
# wikimedia = "https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/Beech_forest_M%C3%A1tra_in_winter.jpg/640px-Beech_forest_M%C3%A1tra_in_winter.jpg"
# wikimedia = "https://upload.wikimedia.org/wikipedia/commons/thumb/b/b2/Marmolada_Sunset.jpg/640px-Marmolada_Sunset.jpg"
# wikimedia = "https://upload.wikimedia.org/wikipedia/commons/thumb/9/9d/Aegina_sunset.jpg/640px-Aegina_sunset.jpg"
input_img = read_image(wikimedia)
TensorFlow Hub からモデルを選択する
TensorFlow Hub には、Half (1/2)、Quarter (1/4)、Three Quarters (3/4) という 3 つのバージョンの Boundless モデルがあります。以下のセルでは、そのうちの任意の 1 つを選択して独自のイメージで試すことができます。別のバージョンのモデルで試す場合には、試したいバージョンを選択してから以下のセルを実行してください。
Model Selection
model_name = 'Boundless Quarter' # @param ['Boundless Half', 'Boundless Quarter', 'Boundless Three Quarters']
model_handle_map = {
'Boundless Half' : 'https://tfhub.dev/google/boundless/half/1',
'Boundless Quarter' : 'https://tfhub.dev/google/boundless/quarter/1',
'Boundless Three Quarters' : 'https://tfhub.dev/google/boundless/three_quarter/1'
}
model_handle = model_handle_map[model_name]
必要なモデルを選択したら、それを TensorFlow Hub から読み込みましょう。
注意: モデルのドキュメントを読むには、ブラウザをモデルのハンドルにポイントします。
print("Loading model {} ({})".format(model_name, model_handle))
model = hub.load(model_handle)
Loading model Boundless Quarter (https://tfhub.dev/google/boundless/quarter/1)
推論を行う
Boundless モデルには 2 つの出力があります。
- マスクされた入力画像の出力
- 外挿して完成させるマスクされた画像
これら 2 つの画像を可視化し、比較表示します。
result = model.signatures['default'](tf.constant(input_img))
generated_image = result['default']
masked_image = result['masked_image']
visualize_output_comparison(input_img, masked_image, generated_image)