W ramach tej współpracy wypróbujesz wiele modeli klasyfikacji obrazów z TensorFlow Hub i zdecydujesz, który z nich jest najlepszy dla Twojego przypadku użycia.
Ponieważ TF Hub zachęca spójną konwencję wejściowe do modeli, które działają na zdjęciach, jest to łatwe do eksperymentowania z różnych architektur, aby znaleźć taki, który najlepiej pasuje do Twoich potrzeb.
Zobacz na TensorFlow.org | Uruchom w Google Colab | Zobacz na GitHub | Pobierz notatnik | Zobacz modele piasty TF |
import tensorflow as tf
import tensorflow_hub as hub
import requests
from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np
Funkcje pomocnicze do ładowania obrazu (ukryte)
original_image_cache = {}
def preprocess_image(image):
image = np.array(image)
# reshape into shape [batch_size, height, width, num_channels]
img_reshaped = tf.reshape(image, [1, image.shape[0], image.shape[1], image.shape[2]])
# Use `convert_image_dtype` to convert to floats in the [0,1] range.
image = tf.image.convert_image_dtype(img_reshaped, tf.float32)
return image
def load_image_from_url(img_url):
"""Returns an image with shape [1, height, width, num_channels]."""
user_agent = {'User-agent': 'Colab Sample (https://tensorflow.org)'}
response = requests.get(img_url, headers=user_agent)
image = Image.open(BytesIO(response.content))
image = preprocess_image(image)
return image
def load_image(image_url, image_size=256, dynamic_size=False, max_dynamic_size=512):
"""Loads and preprocesses images."""
# Cache image file locally.
if image_url in original_image_cache:
img = original_image_cache[image_url]
elif image_url.startswith('https://'):
img = load_image_from_url(image_url)
else:
fd = tf.io.gfile.GFile(image_url, 'rb')
img = preprocess_image(Image.open(fd))
original_image_cache[image_url] = img
# Load and convert to float32 numpy array, add batch dimension, and normalize to range [0, 1].
img_raw = img
if tf.reduce_max(img) > 1.0:
img = img / 255.
if len(img.shape) == 3:
img = tf.stack([img, img, img], axis=-1)
if not dynamic_size:
img = tf.image.resize_with_pad(img, image_size, image_size)
elif img.shape[1] > max_dynamic_size or img.shape[2] > max_dynamic_size:
img = tf.image.resize_with_pad(img, max_dynamic_size, max_dynamic_size)
return img, img_raw
def show_image(image, title=''):
image_size = image.shape[1]
w = (image_size * 6) // 320
plt.figure(figsize=(w, w))
plt.imshow(image[0], aspect='equal')
plt.axis('off')
plt.title(title)
plt.show()
Wybierz model klasyfikacji obrazu. Następnie ustawiane są niektóre zmienne wewnętrzne, a plik etykiet jest pobierany i przygotowywany do użycia.
Istnieją pewne różnice techniczne między modelami, takie jak inny rozmiar danych wejściowych, rozmiar modelu, dokładność i czas wnioskowania. Tutaj możesz zmienić używany model, aż znajdziesz ten najbardziej odpowiedni dla twojego przypadku użycia.
Uchwyt (url) modelu jest drukowany dla Twojej wygody. Więcej dokumentacji na temat każdego modelu jest tam dostępne.
Wybierz model klasyfikacji obrazów
image_size = 224
dynamic_size = False
model_name = "efficientnetv2-s" # @param ['efficientnetv2-s', 'efficientnetv2-m', 'efficientnetv2-l', 'efficientnetv2-s-21k', 'efficientnetv2-m-21k', 'efficientnetv2-l-21k', 'efficientnetv2-xl-21k', 'efficientnetv2-b0-21k', 'efficientnetv2-b1-21k', 'efficientnetv2-b2-21k', 'efficientnetv2-b3-21k', 'efficientnetv2-s-21k-ft1k', 'efficientnetv2-m-21k-ft1k', 'efficientnetv2-l-21k-ft1k', 'efficientnetv2-xl-21k-ft1k', 'efficientnetv2-b0-21k-ft1k', 'efficientnetv2-b1-21k-ft1k', 'efficientnetv2-b2-21k-ft1k', 'efficientnetv2-b3-21k-ft1k', 'efficientnetv2-b0', 'efficientnetv2-b1', 'efficientnetv2-b2', 'efficientnetv2-b3', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'bit_s-r50x1', 'inception_v3', 'inception_resnet_v2', 'resnet_v1_50', 'resnet_v1_101', 'resnet_v1_152', 'resnet_v2_50', 'resnet_v2_101', 'resnet_v2_152', 'nasnet_large', 'nasnet_mobile', 'pnasnet_large', 'mobilenet_v2_100_224', 'mobilenet_v2_130_224', 'mobilenet_v2_140_224', 'mobilenet_v3_small_100_224', 'mobilenet_v3_small_075_224', 'mobilenet_v3_large_100_224', 'mobilenet_v3_large_075_224']
model_handle_map = {
"efficientnetv2-s": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_s/classification/2",
"efficientnetv2-m": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_m/classification/2",
"efficientnetv2-l": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_l/classification/2",
"efficientnetv2-s-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_s/classification/2",
"efficientnetv2-m-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_m/classification/2",
"efficientnetv2-l-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_l/classification/2",
"efficientnetv2-xl-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_xl/classification/2",
"efficientnetv2-b0-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b0/classification/2",
"efficientnetv2-b1-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b1/classification/2",
"efficientnetv2-b2-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b2/classification/2",
"efficientnetv2-b3-21k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_b3/classification/2",
"efficientnetv2-s-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_s/classification/2",
"efficientnetv2-m-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_m/classification/2",
"efficientnetv2-l-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_l/classification/2",
"efficientnetv2-xl-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_xl/classification/2",
"efficientnetv2-b0-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b0/classification/2",
"efficientnetv2-b1-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b1/classification/2",
"efficientnetv2-b2-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b2/classification/2",
"efficientnetv2-b3-21k-ft1k": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b3/classification/2",
"efficientnetv2-b0": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b0/classification/2",
"efficientnetv2-b1": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b1/classification/2",
"efficientnetv2-b2": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b2/classification/2",
"efficientnetv2-b3": "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_b3/classification/2",
"efficientnet_b0": "https://tfhub.dev/tensorflow/efficientnet/b0/classification/1",
"efficientnet_b1": "https://tfhub.dev/tensorflow/efficientnet/b1/classification/1",
"efficientnet_b2": "https://tfhub.dev/tensorflow/efficientnet/b2/classification/1",
"efficientnet_b3": "https://tfhub.dev/tensorflow/efficientnet/b3/classification/1",
"efficientnet_b4": "https://tfhub.dev/tensorflow/efficientnet/b4/classification/1",
"efficientnet_b5": "https://tfhub.dev/tensorflow/efficientnet/b5/classification/1",
"efficientnet_b6": "https://tfhub.dev/tensorflow/efficientnet/b6/classification/1",
"efficientnet_b7": "https://tfhub.dev/tensorflow/efficientnet/b7/classification/1",
"bit_s-r50x1": "https://tfhub.dev/google/bit/s-r50x1/ilsvrc2012_classification/1",
"inception_v3": "https://tfhub.dev/google/imagenet/inception_v3/classification/4",
"inception_resnet_v2": "https://tfhub.dev/google/imagenet/inception_resnet_v2/classification/4",
"resnet_v1_50": "https://tfhub.dev/google/imagenet/resnet_v1_50/classification/4",
"resnet_v1_101": "https://tfhub.dev/google/imagenet/resnet_v1_101/classification/4",
"resnet_v1_152": "https://tfhub.dev/google/imagenet/resnet_v1_152/classification/4",
"resnet_v2_50": "https://tfhub.dev/google/imagenet/resnet_v2_50/classification/4",
"resnet_v2_101": "https://tfhub.dev/google/imagenet/resnet_v2_101/classification/4",
"resnet_v2_152": "https://tfhub.dev/google/imagenet/resnet_v2_152/classification/4",
"nasnet_large": "https://tfhub.dev/google/imagenet/nasnet_large/classification/4",
"nasnet_mobile": "https://tfhub.dev/google/imagenet/nasnet_mobile/classification/4",
"pnasnet_large": "https://tfhub.dev/google/imagenet/pnasnet_large/classification/4",
"mobilenet_v2_100_224": "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/4",
"mobilenet_v2_130_224": "https://tfhub.dev/google/imagenet/mobilenet_v2_130_224/classification/4",
"mobilenet_v2_140_224": "https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/classification/4",
"mobilenet_v3_small_100_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_small_100_224/classification/5",
"mobilenet_v3_small_075_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_small_075_224/classification/5",
"mobilenet_v3_large_100_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_large_100_224/classification/5",
"mobilenet_v3_large_075_224": "https://tfhub.dev/google/imagenet/mobilenet_v3_large_075_224/classification/5",
}
model_image_size_map = {
"efficientnetv2-s": 384,
"efficientnetv2-m": 480,
"efficientnetv2-l": 480,
"efficientnetv2-b0": 224,
"efficientnetv2-b1": 240,
"efficientnetv2-b2": 260,
"efficientnetv2-b3": 300,
"efficientnetv2-s-21k": 384,
"efficientnetv2-m-21k": 480,
"efficientnetv2-l-21k": 480,
"efficientnetv2-xl-21k": 512,
"efficientnetv2-b0-21k": 224,
"efficientnetv2-b1-21k": 240,
"efficientnetv2-b2-21k": 260,
"efficientnetv2-b3-21k": 300,
"efficientnetv2-s-21k-ft1k": 384,
"efficientnetv2-m-21k-ft1k": 480,
"efficientnetv2-l-21k-ft1k": 480,
"efficientnetv2-xl-21k-ft1k": 512,
"efficientnetv2-b0-21k-ft1k": 224,
"efficientnetv2-b1-21k-ft1k": 240,
"efficientnetv2-b2-21k-ft1k": 260,
"efficientnetv2-b3-21k-ft1k": 300,
"efficientnet_b0": 224,
"efficientnet_b1": 240,
"efficientnet_b2": 260,
"efficientnet_b3": 300,
"efficientnet_b4": 380,
"efficientnet_b5": 456,
"efficientnet_b6": 528,
"efficientnet_b7": 600,
"inception_v3": 299,
"inception_resnet_v2": 299,
"mobilenet_v2_100_224": 224,
"mobilenet_v2_130_224": 224,
"mobilenet_v2_140_224": 224,
"nasnet_large": 331,
"nasnet_mobile": 224,
"pnasnet_large": 331,
"resnet_v1_50": 224,
"resnet_v1_101": 224,
"resnet_v1_152": 224,
"resnet_v2_50": 224,
"resnet_v2_101": 224,
"resnet_v2_152": 224,
"mobilenet_v3_small_100_224": 224,
"mobilenet_v3_small_075_224": 224,
"mobilenet_v3_large_100_224": 224,
"mobilenet_v3_large_075_224": 224,
}
model_handle = model_handle_map[model_name]
print(f"Selected model: {model_name} : {model_handle}")
max_dynamic_size = 512
if model_name in model_image_size_map:
image_size = model_image_size_map[model_name]
dynamic_size = False
print(f"Images will be converted to {image_size}x{image_size}")
else:
dynamic_size = True
print(f"Images will be capped to a max size of {max_dynamic_size}x{max_dynamic_size}")
labels_file = "https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt"
#download labels and creates a maps
downloaded_file = tf.keras.utils.get_file("labels.txt", origin=labels_file)
classes = []
with open(downloaded_file) as f:
labels = f.readlines()
classes = [l.strip() for l in labels]
Selected model: efficientnetv2-s : https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet1k_s/classification/2 Images will be converted to 384x384 Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt 16384/10484 [==============================================] - 0s 0us/step 24576/10484 [======================================================================] - 0s 0us/step
Możesz wybrać jeden z poniższych obrazów lub użyć własnego obrazu. Pamiętaj tylko, że rozmiar wejściowy dla modeli jest różny, a niektóre z nich używają dynamicznego rozmiaru wejściowego (umożliwiającego wnioskowanie na nieskalowanym obrazie). Biorąc pod uwagę, że metoda load_image
będzie już przeskalować obraz do oczekiwanego formatu.
Wybierz obraz wejściowy
image_name = "turtle" # @param ['tiger', 'bus', 'car', 'cat', 'dog', 'apple', 'banana', 'turtle', 'flamingo', 'piano', 'honeycomb', 'teapot']
images_for_test_map = {
"tiger": "https://upload.wikimedia.org/wikipedia/commons/b/b0/Bengal_tiger_%28Panthera_tigris_tigris%29_female_3_crop.jpg",
#by Charles James Sharp, CC BY-SA 4.0 <https://creativecommons.org/licenses/by-sa/4.0>, via Wikimedia Commons
"bus": "https://upload.wikimedia.org/wikipedia/commons/6/63/LT_471_%28LTZ_1471%29_Arriva_London_New_Routemaster_%2819522859218%29.jpg",
#by Martin49 from London, England, CC BY 2.0 <https://creativecommons.org/licenses/by/2.0>, via Wikimedia Commons
"car": "https://upload.wikimedia.org/wikipedia/commons/4/49/2013-2016_Toyota_Corolla_%28ZRE172R%29_SX_sedan_%282018-09-17%29_01.jpg",
#by EurovisionNim, CC BY-SA 4.0 <https://creativecommons.org/licenses/by-sa/4.0>, via Wikimedia Commons
"cat": "https://upload.wikimedia.org/wikipedia/commons/4/4d/Cat_November_2010-1a.jpg",
#by Alvesgaspar, CC BY-SA 3.0 <https://creativecommons.org/licenses/by-sa/3.0>, via Wikimedia Commons
"dog": "https://upload.wikimedia.org/wikipedia/commons/archive/a/a9/20090914031557%21Saluki_dog_breed.jpg",
#by Craig Pemberton, CC BY-SA 3.0 <https://creativecommons.org/licenses/by-sa/3.0>, via Wikimedia Commons
"apple": "https://upload.wikimedia.org/wikipedia/commons/1/15/Red_Apple.jpg",
#by Abhijit Tembhekar from Mumbai, India, CC BY 2.0 <https://creativecommons.org/licenses/by/2.0>, via Wikimedia Commons
"banana": "https://upload.wikimedia.org/wikipedia/commons/1/1c/Bananas_white_background.jpg",
#by fir0002 flagstaffotos [at] gmail.com Canon 20D + Tamron 28-75mm f/2.8, GFDL 1.2 <http://www.gnu.org/licenses/old-licenses/fdl-1.2.html>, via Wikimedia Commons
"turtle": "https://upload.wikimedia.org/wikipedia/commons/8/80/Turtle_golfina_escobilla_oaxaca_mexico_claudio_giovenzana_2010.jpg",
#by Claudio Giovenzana, CC BY-SA 3.0 <https://creativecommons.org/licenses/by-sa/3.0>, via Wikimedia Commons
"flamingo": "https://upload.wikimedia.org/wikipedia/commons/b/b8/James_Flamingos_MC.jpg",
#by Christian Mehlführer, User:Chmehl, CC BY 3.0 <https://creativecommons.org/licenses/by/3.0>, via Wikimedia Commons
"piano": "https://upload.wikimedia.org/wikipedia/commons/d/da/Steinway_%26_Sons_upright_piano%2C_model_K-132%2C_manufactured_at_Steinway%27s_factory_in_Hamburg%2C_Germany.png",
#by "Photo: © Copyright Steinway & Sons", CC BY-SA 3.0 <https://creativecommons.org/licenses/by-sa/3.0>, via Wikimedia Commons
"honeycomb": "https://upload.wikimedia.org/wikipedia/commons/f/f7/Honey_comb.jpg",
#by Merdal, CC BY-SA 3.0 <http://creativecommons.org/licenses/by-sa/3.0/>, via Wikimedia Commons
"teapot": "https://upload.wikimedia.org/wikipedia/commons/4/44/Black_tea_pot_cropped.jpg",
#by Mendhak, CC BY-SA 2.0 <https://creativecommons.org/licenses/by-sa/2.0>, via Wikimedia Commons
}
img_url = images_for_test_map[image_name]
image, original_image = load_image(img_url, image_size, dynamic_size, max_dynamic_size)
show_image(image, 'Scaled image')
Teraz, gdy model został wybrany, ładowanie go za pomocą TensorFlow Hub jest proste.
To również wywołuje model z losowymi danymi wejściowymi jako przebieg „rozgrzewkowy”. Kolejne wywołania są często znacznie szybsze i możesz to porównać z opóźnieniem przedstawionym poniżej.
classifier = hub.load(model_handle)
input_shape = image.shape
warmup_input = tf.random.uniform(input_shape, 0, 1.0)
%time warmup_logits = classifier(warmup_input).numpy()
CPU times: user 2.88 s, sys: 470 ms, total: 3.35 s Wall time: 3.41 s
Wszystko jest gotowe do wnioskowania. Tutaj możesz zobaczyć 5 najlepszych wyników z modelu dla wybranego obrazu.
# Run model on image
%time probabilities = tf.nn.softmax(classifier(image)).numpy()
top_5 = tf.argsort(probabilities, axis=-1, direction="DESCENDING")[0][:5].numpy()
np_classes = np.array(classes)
# Some models include an additional 'background' class in the predictions, so
# we must account for this when reading the class labels.
includes_background_class = probabilities.shape[1] == 1001
for i, item in enumerate(top_5):
class_index = item if includes_background_class else item + 1
line = f'({i+1}) {class_index:4} - {classes[class_index]}: {probabilities[0][top_5][i]}'
print(line)
show_image(image, '')
CPU times: user 27.4 ms, sys: 9 µs, total: 27.4 ms Wall time: 25.9 ms (1) 35 - leatherback turtle: 0.7747752666473389 (2) 34 - loggerhead: 0.10644760727882385 (3) 37 - terrapin: 0.005874828901141882 (4) 148 - grey whale: 0.002594555728137493 (5) 36 - mud turtle: 0.0025599468499422073
Ucz się więcej
Jeśli chcesz dowiedzieć się więcej i spróbować, jak to zrobić transferu nauki z tych modeli można spróbować tego samouczka: Nauka transferowych dla klasyfikacji obrazu
Jeśli chcesz sprawdzić na więcej modeli graficznych można je sprawdzić na tfhub.dev