Zobacz na TensorFlow.org | Uruchom w Google Colab | Wyświetl źródło na GitHub | Pobierz notatnik |
Ustawiać
import numpy as np
import tensorflow as tf
from tensorflow import keras
Wstęp
Learning transferu polega na pobraniu funkcje wyuczone na jeden problem, i wykorzystanie ich w nowym, podobnym problemem. Na przykład cechy modelu, który nauczył się rozpoznawać szopy pracze, mogą być przydatne do uruchomienia modelu przeznaczonego do identyfikacji tanuki.
Nauka transferu jest zwykle wykonywana w przypadku zadań, w których zestaw danych zawiera zbyt mało danych, aby wytrenować model w pełnej skali od podstaw.
Najczęstszym wcieleniem transfer learning w kontekście deep learningu jest następujący przepływ pracy:
- Pobierz warstwy z wcześniej wytrenowanego modelu.
- Zamroź je, aby podczas przyszłych rund treningowych nie zniszczyć jakichkolwiek zawartych w nich informacji.
- Dodaj kilka nowych, nadających się do trenowania warstw na wierzchu zamrożonych warstw. Nauczą się przekształcać stare funkcje w prognozy na nowym zbiorze danych.
- Trenuj nowe warstwy w swoim zbiorze danych.
Ostatnim, opcjonalny krok, to dostrajanie, który składa się z odmrożenie cały model uzyskany powyżej (lub jego część) i przekwalifikowanie go na nowych danych z bardzo małą szybkością uczenia się. Może to potencjalnie osiągnąć znaczące ulepszenia, stopniowo dostosowując wstępnie wytrenowane funkcje do nowych danych.
Najpierw pojedziemy nad Keras trainable
API w szczegółach, które leży u podstaw większości uczenia Transfer & dostrajających przepływów pracy.
Następnie zademonstrujemy typowy przepływ pracy, biorąc model wstępnie przeszkolony w zestawie danych ImageNet i przeszkolając go ponownie w zestawie danych klasyfikacji „koty kontra psy” Kaggle.
To jest adaptacją głębokie nauki z Python i 2016 blogu „budowanie potężnych modeli klasyfikacyjnych obraz za pomocą bardzo mało danych” .
Zamrażanie warstw: Zrozumienie trainable
atrybut
Warstwy i modele mają trzy atrybuty wagi:
-
weights
lista wszystkich wag zmiennych warstwy. -
trainable_weights
lista tych, które mają być aktualizowane (poprzez zejście gradientu), aby zminimalizować straty podczas treningu. -
non_trainable_weights
lista tych, które nie mają być przeszkoleni. Zazwyczaj są one aktualizowane przez model podczas przejścia do przodu.
Przykład: Dense
warstwa ma 2 nadającego masy (jądro i polaryzacji)
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2 trainable_weights: 2 non_trainable_weights: 0
Ogólnie rzecz biorąc, wszystkie ciężary są ciężarami, które można trenować. Jedyny wbudowany w warstwę, która ma zakaz wyszkolić ciężarów jest BatchNormalization
warstwa. Wykorzystuje ciężary, których nie można trenować, aby śledzić średnią i wariancję danych wejściowych podczas treningu. Aby dowiedzieć się, jak używać non-wyszkolić ciężarów we własnych niestandardowych warstw, zobacz przewodnik pisanie nowych warstw od zera .
Przykład: BatchNormalization
warstwa ma 2 nadającego wagi i 2 nie nadającego wag
layer = keras.layers.BatchNormalization()
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4 trainable_weights: 2 non_trainable_weights: 2
Warstwy i modele są również wyposażone w logiczną atrybutu trainable
. Jego wartość można zmienić. Ustawianie layer.trainable
do False
ruchów wszystkie ciężary warstwa jest z wyszkolić do nieprzestrzegania wyszkolić. Nazywa się to „zamrożenie” warstwa: stan zamarzniętej warstwy nie będą aktualizowane w trakcie szkolenia (zarówno podczas szkolenia z fit()
lub gdy szkolenie z dowolnej niestandardowej pętli, która opiera się na trainable_weights
zastosować aktualizacje gradient).
Przykład: ustawiania trainable
do False
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
layer.trainable = False # Freeze the layer
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2 trainable_weights: 0 non_trainable_weights: 2
Kiedy waga możliwa do trenowania staje się niemożliwa do wytrenowania, jej wartość nie jest już aktualizowana podczas treningu.
# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])
# Freeze the first layer
layer1.trainable = False
# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()
# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945
Nie należy mylić layer.trainable
atrybut z argumentem training
w layer.__call__()
(która określa, czy warstwa powinna prowadzić swoje podaniu w trybie wnioskowania lub trybu treningowego). Aby uzyskać więcej informacji, zobacz Keras nas .
Rekurencyjne ustawienie trainable
atrybutu
Jeśli ustawisz trainable = False
na modelu lub na dowolnej warstwy, która ma podwarstwy, wszystkie dzieci warstwy stać non-wyszkolić również.
Przykład:
inner_model = keras.Sequential(
[
keras.Input(shape=(3,)),
keras.layers.Dense(3, activation="relu"),
keras.layers.Dense(3, activation="relu"),
]
)
model = keras.Sequential(
[keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)
model.trainable = False # Freeze the outer model
assert inner_model.trainable == False # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False # `trainable` is propagated recursively
Typowy przepływ pracy typu transfer-learning
To prowadzi nas do tego, jak typowy przepływ uczenia się transferowego można wdrożyć w Keras:
- Utwórz wystąpienie modelu podstawowego i załaduj do niego wstępnie wytrenowane wagi.
- Zamrozić wszystkie warstwy w modelu bazowym poprzez ustawienie
trainable = False
. - Utwórz nowy model na podstawie danych wyjściowych jednej (lub kilku) warstw z modelu podstawowego.
- Wytrenuj nowy model na nowym zbiorze danych.
Zwróć uwagę, że alternatywnym, lżejszym przepływem pracy może być również:
- Utwórz wystąpienie modelu podstawowego i załaduj do niego wstępnie wytrenowane wagi.
- Przeprowadź przez niego nowy zestaw danych i zapisz dane wyjściowe jednej (lub kilku) warstw z modelu podstawowego. Jest to tak zwana funkcja ekstrakcji.
- Użyj tych danych wyjściowych jako danych wejściowych dla nowego, mniejszego modelu.
Kluczową zaletą tego drugiego przepływu pracy jest to, że model podstawowy jest uruchamiany tylko raz na danych, a nie raz na epokę uczenia. Więc jest o wiele szybciej i taniej.
Problem z tym drugim przepływem pracy polega jednak na tym, że nie pozwala on na dynamiczną modyfikację danych wejściowych nowego modelu podczas uczenia, co jest wymagane na przykład podczas rozszerzania danych. Uczenie się przenoszenia jest zwykle używane w przypadku zadań, w których nowy zestaw danych zawiera zbyt mało danych, aby można było wytrenować model w pełnej skali od podstaw, a w takich scenariuszach bardzo ważne jest rozszerzanie danych. W dalszej części skupimy się na pierwszym przepływie pracy.
Oto jak wygląda pierwszy przepływ pracy w Keras:
Najpierw utwórz wystąpienie modelu podstawowego ze wstępnie wytrenowanymi wagami.
base_model = keras.applications.Xception(
weights='imagenet', # Load weights pre-trained on ImageNet.
input_shape=(150, 150, 3),
include_top=False) # Do not include the ImageNet classifier at the top.
Następnie zamroź model podstawowy.
base_model.trainable = False
Utwórz nowy model na górze.
inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
Trenuj model na nowych danych.
model.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)
Strojenie
Gdy model osiągnie zbieżność na nowych danych, możesz spróbować odblokować całość lub część modelu podstawowego i przeszkolić cały model od początku do końca z bardzo niskim współczynnikiem uczenia.
Jest to opcjonalny ostatni krok, który może potencjalnie zapewnić stopniową poprawę. Może to również potencjalnie prowadzić do szybkiego overfittingu – miej to na uwadze.
Bardzo ważne jest, aby tylko zrobić ten krok po model z zamrożonych warstw został przeszkolony do konwergencji. Jeśli zmieszasz losowo inicjowane warstwy możliwe do trenowania z warstwami możliwymi do trenowania, które zawierają wstępnie wytrenowane funkcje, losowo zainicjowane warstwy spowodują bardzo duże aktualizacje gradientu podczas treningu, co zniszczy wstępnie wytrenowane funkcje.
Bardzo ważne jest również użycie bardzo niskiego współczynnika uczenia się na tym etapie, ponieważ trenujesz znacznie większy model niż w pierwszej rundzie uczenia, na zestawie danych, który jest zwykle bardzo mały. W rezultacie istnieje ryzyko bardzo szybkiego przeciążenia, jeśli zastosujesz duże aktualizacje wagi. Tutaj chcesz tylko dostosować wstępnie wytrenowane wagi w sposób przyrostowy.
Oto jak zaimplementować dostrojenie całego modelu podstawowego:
# Unfreeze the base model
base_model.trainable = True
# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5), # Very low learning rate
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()])
# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)
Ważna uwaga o compile()
i trainable
Wywoływanie compile()
na modelu rozumie się „zamrożenia” zachowanie tego modelu. Oznacza to, że trainable
wartości atrybutów w czasie model jest kompilowany powinny być zachowane przez cały okres użytkowania tego modelu, aż do compile
nazywa się ponownie. Stąd, jeśli zmienić dowolny trainable
wartość, upewnij się, aby zadzwonić do compile()
ponownie w modelu na zmiany mają być brane pod uwagę.
Ważne informacje o BatchNormalization
warstwie
Wiele modeli graficznych zawierają BatchNormalization
warstw. Ta warstwa jest szczególnym przypadkiem pod każdym możliwym względem. Oto kilka rzeczy, o których należy pamiętać.
-
BatchNormalization
zawiera 2 non-wyszkolić ciężary, które aktualizowane w czasie treningu. Są to zmienne śledzące średnią i wariancję danych wejściowych. - Po ustawieniu
bn_layer.trainable = False
TheBatchNormalization
warstwa będzie działać w trybie wnioskowania i nie aktualizuje swoich średnich i wariancji statystyk. To nie jest sprawa dla innych warstw Ogólnie, jak waga trainability & wnioskowania tryby szkolenia / są dwie prostopadłe koncepcje . Ale dwa są związane w przypadkuBatchNormalization
warstwy. - Kiedy odmrozić model, który zawiera
BatchNormalization
warstw w tym celu dostrajania, należy zachowaćBatchNormalization
warstwy w trybie wnioskowania o przejściutraining=False
podczas wywoływania modelu bazowego. W przeciwnym razie aktualizacje zastosowane do wag, których nie można wyszkolić, nagle zniszczą to, czego nauczył się model.
Zobaczysz ten wzorzec w akcji w kompletnym przykładzie na końcu tego przewodnika.
Przenieś naukę i dostrajanie za pomocą niestandardowej pętli treningowej
Jeśli zamiast fit()
, używasz własną pętlę szkoleniowy niskim poziomie, pobyty workflow w zasadzie takie same. Należy uważać, aby wziąć pod uwagę tylko listy model.trainable_weights
podczas stosowania aktualizacji gradient:
# Create base model
base_model = keras.applications.Xception(
weights='imagenet',
input_shape=(150, 150, 3),
include_top=False)
# Freeze base model
base_model.trainable = False
# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()
# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
# Open a GradientTape.
with tf.GradientTape() as tape:
# Forward pass.
predictions = model(inputs)
# Compute the loss value for this batch.
loss_value = loss_fn(targets, predictions)
# Get gradients of loss wrt the *trainable* weights.
gradients = tape.gradient(loss_value, model.trainable_weights)
# Update the weights of the model.
optimizer.apply_gradients(zip(gradients, model.trainable_weights))
Podobnie do dostrajania.
Kompleksowy przykład: dopracowanie modelu klasyfikacji obrazów w zestawie danych koty i psy
Aby utrwalić te koncepcje, przeprowadźmy Cię przez konkretny przykład kompleksowego uczenia się i dostrajania. Załadujemy model Xception, wstępnie wytrenowany w ImageNet, i użyjemy go w zestawie danych klasyfikacji „koty kontra psy” Kaggle.
Uzyskiwanie danych
Najpierw pobierzmy zestaw danych koty kontra psy za pomocą TFDS. Jeśli masz swój własny zestaw danych, prawdopodobnie będziesz chciał użyć narzędzia tf.keras.preprocessing.image_dataset_from_directory
generować podobne obiekty oznaczone zestaw danych ze zbioru obrazów na dysku złożone w foldery klasy specyficzne.
Uczenie się transferu jest najbardziej przydatne podczas pracy z bardzo małymi zestawami danych. Aby nasz zestaw danych był niewielki, użyjemy 40% oryginalnych danych treningowych (25 000 obrazów) do trenowania, 10% do walidacji i 10% do testowania.
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
train_ds, validation_ds, test_ds = tfds.load(
"cats_vs_dogs",
# Reserve 10% for validation and 10% for test
split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
as_supervised=True, # Include labels
)
print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
"Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305 Number of validation samples: 2326 Number of test samples: 2326
Oto pierwsze 9 obrazów w treningowym zbiorze danych — jak widać, wszystkie mają różne rozmiary.
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image)
plt.title(int(label))
plt.axis("off")
Widzimy również, że etykieta 1 to „pies”, a etykieta 0 to „kot”.
Standaryzacja danych
Nasze surowe obrazy mają różne rozmiary. Ponadto każdy piksel składa się z 3 wartości całkowitych z zakresu od 0 do 255 (wartości na poziomie RGB). Nie jest to idealne rozwiązanie do zasilania sieci neuronowej. Musimy zrobić 2 rzeczy:
- Standaryzuj do stałego rozmiaru obrazu. Wybieramy 150x150.
- Znormalizować wartości pikseli pomiędzy -1 a 1. Będziemy to robić za pomocą
Normalization
warstwę jako część samego modelu.
Ogólnie rzecz biorąc, dobrą praktyką jest tworzenie modeli, które pobierają nieprzetworzone dane jako dane wejściowe, w przeciwieństwie do modeli, które przyjmują już wstępnie przetworzone dane. Powodem jest to, że jeśli model oczekuje wstępnie przetworzonych danych, za każdym razem, gdy eksportujesz model, aby użyć go w innym miejscu (w przeglądarce internetowej, w aplikacji mobilnej), będziesz musiał ponownie zaimplementować dokładnie ten sam potok przetwarzania wstępnego. To bardzo szybko staje się trudne. Powinniśmy więc wykonać najmniejszą możliwą ilość wstępnego przetwarzania przed uderzeniem w model.
Tutaj dokonamy zmiany rozmiaru obrazu w potoku danych (ponieważ głęboka sieć neuronowa może przetwarzać tylko ciągłe partie danych) i wykonamy skalowanie wartości wejściowej jako część modelu podczas jego tworzenia.
Zmieńmy rozmiar obrazów na 150x150:
size = (150, 150)
train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))
Poza tym zbierzmy dane i użyjmy buforowania i pobierania z wyprzedzeniem, aby zoptymalizować prędkość ładowania.
batch_size = 32
train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)
Korzystanie z losowego zwiększania danych
Jeśli nie masz dużego zestawu danych obrazu, dobrą praktyką jest sztuczne wprowadzenie różnorodności próbek przez zastosowanie losowych, ale realistycznych przekształceń do obrazów szkoleniowych, takich jak losowe odwracanie w poziomie lub małe losowe obroty. Pomaga to wystawić model na różne aspekty danych uczących, jednocześnie spowalniając nadmierne dopasowanie.
from tensorflow import keras
from tensorflow.keras import layers
data_augmentation = keras.Sequential(
[layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)
Wyobraźmy sobie, jak wygląda pierwszy obraz pierwszej partii po różnych losowych przekształceniach:
import numpy as np
for images, labels in train_ds.take(1):
plt.figure(figsize=(10, 10))
first_image = images[0]
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
augmented_image = data_augmentation(
tf.expand_dims(first_image, 0), training=True
)
plt.imshow(augmented_image[0].numpy().astype("int32"))
plt.title(int(labels[0]))
plt.axis("off")
2021-09-01 18:45:34.772284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Zbudować model
Teraz zbudujmy model zgodny z planem, który wyjaśniliśmy wcześniej.
Zwróć uwagę, że:
- Dodajmy do
Rescaling
warstwy do wartości wejściowych skalę (początkowo w[0, 255]
zakresu) w[-1, 1]
zakresu. - Dodamy
Dropout
warstwy przed nałożeniem warstwy klasyfikacji, dla uregulowania. - Dbamy o to, aby przejść
training=False
podczas wywoływania modelu bazowego, tak, że działa w trybie wnioskowania, dzięki czemu statystyki batchnorm nie aktualizowane nawet po odmrozić model bazowy dostrajających.
base_model = keras.applications.Xception(
weights="imagenet", # Load weights pre-trained on ImageNet.
input_shape=(150, 150, 3),
include_top=False,
) # Do not include the ImageNet classifier at the top.
# Freeze the base_model
base_model.trainable = False
# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs) # Apply random data augmentation
# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)
# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x) # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5 83689472/83683744 [==============================] - 2s 0us/step 83697664/83683744 [==============================] - 2s 0us/step Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_5 (InputLayer) [(None, 150, 150, 3)] 0 _________________________________________________________________ sequential_3 (Sequential) (None, 150, 150, 3) 0 _________________________________________________________________ rescaling (Rescaling) (None, 150, 150, 3) 0 _________________________________________________________________ xception (Functional) (None, 5, 5, 2048) 20861480 _________________________________________________________________ global_average_pooling2d (Gl (None, 2048) 0 _________________________________________________________________ dropout (Dropout) (None, 2048) 0 _________________________________________________________________ dense_7 (Dense) (None, 1) 2049 ================================================================= Total params: 20,863,529 Trainable params: 2,049 Non-trainable params: 20,861,480 _________________________________________________________________
Trenuj górną warstwę
model.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20 151/291 [==============>...............] - ETA: 3s - loss: 0.1979 - binary_accuracy: 0.9096 Corrupt JPEG data: 65 extraneous bytes before marker 0xd9 268/291 [==========================>...] - ETA: 1s - loss: 0.1663 - binary_accuracy: 0.9269 Corrupt JPEG data: 239 extraneous bytes before marker 0xd9 282/291 [============================>.] - ETA: 0s - loss: 0.1628 - binary_accuracy: 0.9284 Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9 Corrupt JPEG data: 228 extraneous bytes before marker 0xd9 291/291 [==============================] - ETA: 0s - loss: 0.1620 - binary_accuracy: 0.9286 Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9 291/291 [==============================] - 29s 63ms/step - loss: 0.1620 - binary_accuracy: 0.9286 - val_loss: 0.0814 - val_binary_accuracy: 0.9686 Epoch 2/20 291/291 [==============================] - 8s 29ms/step - loss: 0.1178 - binary_accuracy: 0.9511 - val_loss: 0.0785 - val_binary_accuracy: 0.9695 Epoch 3/20 291/291 [==============================] - 9s 30ms/step - loss: 0.1121 - binary_accuracy: 0.9536 - val_loss: 0.0748 - val_binary_accuracy: 0.9712 Epoch 4/20 291/291 [==============================] - 9s 29ms/step - loss: 0.1082 - binary_accuracy: 0.9554 - val_loss: 0.0754 - val_binary_accuracy: 0.9703 Epoch 5/20 291/291 [==============================] - 8s 29ms/step - loss: 0.1034 - binary_accuracy: 0.9570 - val_loss: 0.0721 - val_binary_accuracy: 0.9725 Epoch 6/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0975 - binary_accuracy: 0.9602 - val_loss: 0.0748 - val_binary_accuracy: 0.9699 Epoch 7/20 291/291 [==============================] - 9s 29ms/step - loss: 0.0989 - binary_accuracy: 0.9595 - val_loss: 0.0732 - val_binary_accuracy: 0.9716 Epoch 8/20 291/291 [==============================] - 8s 29ms/step - loss: 0.1027 - binary_accuracy: 0.9566 - val_loss: 0.0787 - val_binary_accuracy: 0.9678 Epoch 9/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0959 - binary_accuracy: 0.9614 - val_loss: 0.0734 - val_binary_accuracy: 0.9729 Epoch 10/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0995 - binary_accuracy: 0.9588 - val_loss: 0.0717 - val_binary_accuracy: 0.9721 Epoch 11/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0957 - binary_accuracy: 0.9612 - val_loss: 0.0731 - val_binary_accuracy: 0.9725 Epoch 12/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0936 - binary_accuracy: 0.9622 - val_loss: 0.0751 - val_binary_accuracy: 0.9716 Epoch 13/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0965 - binary_accuracy: 0.9610 - val_loss: 0.0821 - val_binary_accuracy: 0.9695 Epoch 14/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0939 - binary_accuracy: 0.9618 - val_loss: 0.0742 - val_binary_accuracy: 0.9712 Epoch 15/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0974 - binary_accuracy: 0.9585 - val_loss: 0.0771 - val_binary_accuracy: 0.9712 Epoch 16/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9621 - val_loss: 0.0823 - val_binary_accuracy: 0.9699 Epoch 17/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9625 - val_loss: 0.0718 - val_binary_accuracy: 0.9708 Epoch 18/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0928 - binary_accuracy: 0.9616 - val_loss: 0.0738 - val_binary_accuracy: 0.9716 Epoch 19/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0922 - binary_accuracy: 0.9644 - val_loss: 0.0743 - val_binary_accuracy: 0.9716 Epoch 20/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0885 - binary_accuracy: 0.9635 - val_loss: 0.0745 - val_binary_accuracy: 0.9695 <keras.callbacks.History at 0x7f849a3b2950>
Wykonaj rundę dostrajania całego modelu
Na koniec odblokujmy model podstawowy i wytrenujmy cały model od początku do końca z niskim współczynnikiem uczenia się.
Co ważne, mimo że model podstawowy staje się wyszkolić, to nadal działa w trybie wnioskowania odkąd przeszedł training=False
, gdy dzwoni, gdy zbudowaliśmy model. Oznacza to, że znajdujące się wewnątrz warstwy normalizacji wsadowej nie będą aktualizować swoich statystyk wsadowych. Gdyby to zrobili, zniszczyliby reprezentacje poznane przez model do tej pory.
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()
model.compile(
optimizer=keras.optimizers.Adam(1e-5), # Low learning rate
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_5 (InputLayer) [(None, 150, 150, 3)] 0 _________________________________________________________________ sequential_3 (Sequential) (None, 150, 150, 3) 0 _________________________________________________________________ rescaling (Rescaling) (None, 150, 150, 3) 0 _________________________________________________________________ xception (Functional) (None, 5, 5, 2048) 20861480 _________________________________________________________________ global_average_pooling2d (Gl (None, 2048) 0 _________________________________________________________________ dropout (Dropout) (None, 2048) 0 _________________________________________________________________ dense_7 (Dense) (None, 1) 2049 ================================================================= Total params: 20,863,529 Trainable params: 20,809,001 Non-trainable params: 54,528 _________________________________________________________________ Epoch 1/10 291/291 [==============================] - 43s 131ms/step - loss: 0.0802 - binary_accuracy: 0.9692 - val_loss: 0.0580 - val_binary_accuracy: 0.9764 Epoch 2/10 291/291 [==============================] - 37s 128ms/step - loss: 0.0542 - binary_accuracy: 0.9792 - val_loss: 0.0529 - val_binary_accuracy: 0.9764 Epoch 3/10 291/291 [==============================] - 37s 128ms/step - loss: 0.0400 - binary_accuracy: 0.9832 - val_loss: 0.0510 - val_binary_accuracy: 0.9798 Epoch 4/10 291/291 [==============================] - 37s 128ms/step - loss: 0.0313 - binary_accuracy: 0.9879 - val_loss: 0.0505 - val_binary_accuracy: 0.9819 Epoch 5/10 291/291 [==============================] - 37s 128ms/step - loss: 0.0272 - binary_accuracy: 0.9904 - val_loss: 0.0485 - val_binary_accuracy: 0.9807 Epoch 6/10 291/291 [==============================] - 37s 128ms/step - loss: 0.0284 - binary_accuracy: 0.9901 - val_loss: 0.0497 - val_binary_accuracy: 0.9824 Epoch 7/10 291/291 [==============================] - 37s 127ms/step - loss: 0.0198 - binary_accuracy: 0.9937 - val_loss: 0.0530 - val_binary_accuracy: 0.9802 Epoch 8/10 291/291 [==============================] - 37s 127ms/step - loss: 0.0173 - binary_accuracy: 0.9930 - val_loss: 0.0572 - val_binary_accuracy: 0.9819 Epoch 9/10 291/291 [==============================] - 37s 127ms/step - loss: 0.0113 - binary_accuracy: 0.9958 - val_loss: 0.0555 - val_binary_accuracy: 0.9837 Epoch 10/10 291/291 [==============================] - 37s 127ms/step - loss: 0.0091 - binary_accuracy: 0.9966 - val_loss: 0.0596 - val_binary_accuracy: 0.9832 <keras.callbacks.History at 0x7f83982d4cd0>
Po 10 epokach dostrajanie daje nam tutaj niezłą poprawę.