Transferir el aprendizaje y la puesta a punto

Ver en TensorFlow.org Ejecutar en Google Colab Ver fuente en GitHub Descargar cuaderno

Configuración

import numpy as np
import tensorflow as tf
from tensorflow import keras

Introducción

Transferencia de aprendizaje consiste en tomar características aprendidas en un problema, y el aprovechamiento de ellos en un nuevo problema similar. Por ejemplo, las características de un modelo que ha aprendido a identificar mapaches pueden ser útiles para poner en marcha un modelo destinado a identificar tanukis.

El aprendizaje de transferencia generalmente se realiza para tareas en las que su conjunto de datos tiene muy pocos datos para entrenar un modelo a gran escala desde cero.

La encarnación más común del aprendizaje por transferencia en el contexto del aprendizaje profundo es el siguiente flujo de trabajo:

  1. Tome capas de un modelo previamente entrenado.
  2. Congélelos para evitar destruir la información que contienen durante las rondas de entrenamiento futuras.
  3. Agregue algunas capas nuevas y entrenables sobre las capas congeladas. Aprenderán a convertir las características antiguas en predicciones en un nuevo conjunto de datos.
  4. Entrene las nuevas capas en su conjunto de datos.

Un último paso, opcional, es puesta a punto, que consiste en descongelar todo el modelo haya obtenido anteriormente (o parte de ella), y volver a entrenar en los nuevos datos con una velocidad de aprendizaje muy baja. Potencialmente, esto puede lograr mejoras significativas al adaptar gradualmente las funciones previamente entrenadas a los nuevos datos.

En primer lugar, vamos a ir sobre la Keras trainable API en detalle, que subyace en la mayoría de los flujos de trabajo de aprendizaje y transferencia de ajuste.

Luego, demostraremos el flujo de trabajo típico al tomar un modelo previamente entrenado en el conjunto de datos de ImageNet y volver a entrenarlo en el conjunto de datos de clasificación "gatos contra perros" de Kaggle.

Esta es una adaptación de profundo aprendizaje con Python y la entrada del blog 2016 "la construcción de potentes modelos de clasificación de imágenes utilizando muy pocos datos" .

Capas congeladas: Comprensión de la trainable atributo

Las capas y los modelos tienen tres atributos de peso:

  • weights es la lista de todas las variables de pesos de la capa.
  • trainable_weights es la lista de aquellos que están destinados a ser actualizado (a través de descenso de gradiente) para minimizar la pérdida durante el entrenamiento.
  • non_trainable_weights es la lista de los que no están destinados a ser entrenado. Por lo general, el modelo las actualiza durante el pase hacia adelante.

Ejemplo: la Dense capa tiene 2 pesos entrenables (kernel y sesgo)

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 2
non_trainable_weights: 0

En general, todos los pesos son pesos entrenables. El único incorporado en la capa que tiene los pesos no entrenables es el BatchNormalization capa. Utiliza pesos no entrenables para realizar un seguimiento de la media y la varianza de sus entradas durante el entrenamiento. Para aprender a usar los pesos no entrenables en sus propias capas personalizadas, consulte la guía para escribir nuevas capas a partir de cero .

Ejemplo: el BatchNormalization capa tiene 2 pesos entrenables y 2 pesos no entrenables

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4
trainable_weights: 2
non_trainable_weights: 2

Capas y modelos también cuentan con un atributo booleano trainable . Su valor se puede cambiar. Configuración layer.trainable a False mueve todos los pesos de la capa de entrenable a no entrenable. Esto se llama "congelación" de la capa: el estado de una capa congelada no se actualizará durante el entrenamiento (ya sea cuando el entrenamiento con el fit() o cuando el entrenamiento con cualquier lazo personalizado que se basa en trainable_weights para aplicar las actualizaciones de gradiente).

Ejemplo: ajuste trainable a False

layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2
trainable_weights: 0
non_trainable_weights: 2

Cuando un peso entrenable se vuelve no entrenable, su valor ya no se actualiza durante el entrenamiento.

# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945

No hay que confundir la layer.trainable atributo con el argumento de training de layer.__call__() (que controla si la capa debe ejecutar su paso hacia adelante en el modo de inferencia o el modo de entrenamiento). Para obtener más información, consulte la Keras FAQ .

Configuración recursiva de la trainable atributo

Si configura trainable = False en un modelo o en cualquier capa que tiene subcapas, todos los niños se convierten en capas no entrenable también.

Ejemplo:

inner_model = keras.Sequential(
    [
        keras.Input(shape=(3,)),
        keras.layers.Dense(3, activation="relu"),
        keras.layers.Dense(3, activation="relu"),
    ]
)

model = keras.Sequential(
    [keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively

El típico flujo de trabajo de aprendizaje por transferencia

Esto nos lleva a cómo se puede implementar un flujo de trabajo de aprendizaje por transferencia típico en Keras:

  1. Cree una instancia de un modelo base y cargue pesos previamente entrenados en él.
  2. Congelación de todas las capas del modelo de base mediante el establecimiento de trainable = False .
  3. Cree un nuevo modelo sobre la salida de una (o varias) capas del modelo base.
  4. Entrena tu nuevo modelo en tu nuevo conjunto de datos.

Tenga en cuenta que un flujo de trabajo alternativo y más ligero también podría ser:

  1. Cree una instancia de un modelo base y cargue pesos previamente entrenados en él.
  2. Ejecute su nuevo conjunto de datos a través de él y registre la salida de una (o varias) capas del modelo base. Esto se conoce como extracción de características.
  3. Utilice esa salida como datos de entrada para un modelo nuevo y más pequeño.

Una ventaja clave de ese segundo flujo de trabajo es que solo ejecuta el modelo base una vez en sus datos, en lugar de una vez por época de entrenamiento. Así que es mucho más rápido y más barato.

Sin embargo, un problema con ese segundo flujo de trabajo es que no le permite modificar dinámicamente los datos de entrada de su nuevo modelo durante el entrenamiento, lo cual es necesario cuando se realiza un aumento de datos, por ejemplo. El aprendizaje de transferencia generalmente se usa para tareas cuando su nuevo conjunto de datos tiene muy pocos datos para entrenar un modelo a gran escala desde cero, y en tales escenarios, el aumento de datos es muy importante. Entonces, en lo que sigue, nos centraremos en el primer flujo de trabajo.

Así es como se ve el primer flujo de trabajo en Keras:

Primero, cree una instancia de un modelo base con pesos previamente entrenados.

base_model = keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

Luego, congela el modelo base.

base_model.trainable = False

Crea un nuevo modelo en la parte superior.

inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

Entrene el modelo con nuevos datos.

model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)

Sintonia FINA

Una vez que su modelo haya convergido en los nuevos datos, puede intentar descongelar todo o parte del modelo base y volver a entrenar todo el modelo de principio a fin con una tasa de aprendizaje muy baja.

Este es un último paso opcional que potencialmente puede brindarle mejoras incrementales. También podría conducir a un sobreajuste rápido, tenlo en cuenta.

Es crítico sólo para hacer este paso después de que el modelo con capas congeladas ha sido entrenado para la convergencia. Si mezcla capas entrenables inicializadas aleatoriamente con capas entrenables que contienen funciones preentrenadas, las capas inicializadas aleatoriamente generarán actualizaciones de gradiente muy grandes durante el entrenamiento, lo que destruirá sus funciones preentrenadas.

También es fundamental usar una tasa de aprendizaje muy baja en esta etapa, porque está entrenando un modelo mucho más grande que en la primera ronda de entrenamiento, en un conjunto de datos que suele ser muy pequeño. Como resultado, corre el riesgo de sobreajustar muy rápidamente si aplica grandes actualizaciones de peso. Aquí, solo desea readaptar los pesos preentrenados de manera incremental.

Así es como se implementa el ajuste fino de todo el modelo base:

# Unfreeze the base model
base_model.trainable = True

# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate
              loss=keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=[keras.metrics.BinaryAccuracy()])

# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)

Nota importante acerca compile() y trainable

Llamando compile() en un modelo que se entiende a "congelar" el comportamiento de ese modelo. Esto implica que los trainable valores de atributos en el momento se compila el modelo debe ser preservado a lo largo del curso de la vida de ese modelo, hasta que compile se llama de nuevo. Por lo tanto, si cambia cualquier trainable valor, asegúrese de llamada compile() de nuevo en su modelo para los cambios que deben tenerse en cuenta.

Notas importantes sobre BatchNormalization capa

Muchos modelos de imagen contienen BatchNormalization capas. Esa capa es un caso especial en todos los aspectos imaginables. Aquí hay algunas cosas a tener en cuenta.

  • BatchNormalization contiene 2 los pesos no entrenables que se actualizan durante el entrenamiento. Estas son las variables que siguen la media y la varianza de las entradas.
  • Cuando se establece bn_layer.trainable = False , el BatchNormalization capa se ejecutará en modo de inferencia, y no actualizará sus estadísticas medias y varianzas. Este no es el caso para otras capas en general, como el peso y la capacidad de formación de inferencia / modos de entrenamiento son dos conceptos ortogonales . Pero los dos están vinculados en el caso de la BatchNormalization capa.
  • Al descongelar un modelo que contiene BatchNormalization capas con el fin de hacer el ajuste fino, se debe mantener el BatchNormalization capas en modo de inferencia mediante el paso training=False cuando se llama el modelo base. De lo contrario, las actualizaciones aplicadas a los pesos no entrenables destruirán repentinamente lo que ha aprendido el modelo.

Verá este patrón en acción en el ejemplo de extremo a extremo al final de esta guía.

Transfiera el aprendizaje y ajuste con un ciclo de entrenamiento personalizado

Si en lugar de fit() , está utilizando su propio bucle de entrenamiento bajo nivel, las estancias de flujo de trabajo esencialmente el mismo. Debe tener cuidado de tomar sólo en cuenta la lista model.trainable_weights al aplicar cambios de gradiente:

# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
    # Open a GradientTape.
    with tf.GradientTape() as tape:
        # Forward pass.
        predictions = model(inputs)
        # Compute the loss value for this batch.
        loss_value = loss_fn(targets, predictions)

    # Get gradients of loss wrt the *trainable* weights.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

Del mismo modo para el ajuste fino.

Un ejemplo de principio a fin: ajuste de un modelo de clasificación de imágenes en un conjunto de datos de perros y gatos

Para solidificar estos conceptos, lo guiaremos a través de un ejemplo concreto de ajuste y aprendizaje de transferencia de un extremo a otro. Cargaremos el modelo Xception, previamente entrenado en ImageNet, y lo usaremos en el conjunto de datos de clasificación "gatos contra perros" de Kaggle.

Obteniendo los datos

Primero, obtengamos el conjunto de datos de gatos contra perros usando TFDS. Si usted tiene su propio conjunto de datos, es probable que desee utilizar la utilidad tf.keras.preprocessing.image_dataset_from_directory para generar el conjunto de datos etiquetados similares objetos de un conjunto de imágenes en el disco presentadas en carpetas específicas de clase.

El aprendizaje de transferencia es más útil cuando se trabaja con conjuntos de datos muy pequeños. Para mantener nuestro conjunto de datos pequeño, utilizaremos el 40 % de los datos de entrenamiento originales (25 000 imágenes) para el entrenamiento, el 10 % para la validación y el 10 % para las pruebas.

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

train_ds, validation_ds, test_ds = tfds.load(
    "cats_vs_dogs",
    # Reserve 10% for validation and 10% for test
    split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
    as_supervised=True,  # Include labels
)

print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
    "Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305
Number of validation samples: 2326
Number of test samples: 2326

Estas son las primeras 9 imágenes del conjunto de datos de entrenamiento; como puede ver, todas son de diferentes tamaños.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(image)
    plt.title(int(label))
    plt.axis("off")

png

También podemos ver que la etiqueta 1 es "perro" y la etiqueta 0 es "gato".

Estandarizando los datos

Nuestras imágenes en bruto tienen una variedad de tamaños. Además, cada píxel consta de 3 valores enteros entre 0 y 255 (valores de nivel RGB). Esto no es muy adecuado para alimentar una red neuronal. Tenemos que hacer 2 cosas:

  • Estandarizar a un tamaño de imagen fijo. Elegimos 150x150.
  • Normalizar los valores de píxeles entre -1 y 1. Haremos esto utilizando una Normalization capa como parte del propio modelo.

En general, es una buena práctica desarrollar modelos que toman datos sin procesar como entrada, a diferencia de modelos que toman datos ya preprocesados. El motivo es que, si su modelo espera datos preprocesados, cada vez que exporte su modelo para usarlo en otro lugar (en un navegador web, en una aplicación móvil), deberá volver a implementar exactamente la misma canalización de preprocesamiento. Esto se vuelve muy complicado muy rápidamente. Por tanto, deberíamos hacer la menor cantidad posible de preprocesamiento antes de utilizar el modelo.

Aquí, cambiaremos el tamaño de la imagen en la canalización de datos (porque una red neuronal profunda solo puede procesar lotes de datos contiguos) y escalaremos el valor de entrada como parte del modelo, cuando lo creemos.

Redimensionemos las imágenes a 150x150:

size = (150, 150)

train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))

Además, vamos a agrupar los datos y utilizar el almacenamiento en caché y la captación previa para optimizar la velocidad de carga.

batch_size = 32

train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)

Uso de aumento de datos aleatorios

Cuando no tiene un conjunto de datos de imágenes grande, es una buena práctica introducir artificialmente diversidad de muestras aplicando transformaciones aleatorias pero realistas a las imágenes de entrenamiento, como cambios horizontales aleatorios o pequeñas rotaciones aleatorias. Esto ayuda a exponer el modelo a diferentes aspectos de los datos de entrenamiento mientras ralentiza el sobreajuste.

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential(
    [layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)

Visualicemos cómo se ve la primera imagen del primer lote después de varias transformaciones aleatorias:

import numpy as np

for images, labels in train_ds.take(1):
    plt.figure(figsize=(10, 10))
    first_image = images[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(
            tf.expand_dims(first_image, 0), training=True
        )
        plt.imshow(augmented_image[0].numpy().astype("int32"))
        plt.title(int(labels[0]))
        plt.axis("off")
2021-09-01 18:45:34.772284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

png

Construir un modelo

Ahora construyamos un modelo que siga el modelo que hemos explicado anteriormente.

Tenga en cuenta que:

  • Añadimos un Rescaling capa a valores de entrada escala (inicialmente en el [0, 255] gama) a la [-1, 1] gama.
  • Añadimos una Dropout capa antes de la capa de clasificación, para la regularización.
  • Nos aseguramos de pasar training=False cuando se llama el modelo base, para que se ejecute en modo de inferencia, por lo que las estadísticas batchnorm no se actualizan, incluso después de descongelar el modelo base para la puesta a punto.
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5
83689472/83683744 [==============================] - 2s 0us/step
83697664/83683744 [==============================] - 2s 0us/step
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 2,049
Non-trainable params: 20,861,480
_________________________________________________________________

Entrena la capa superior

model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20
151/291 [==============>...............] - ETA: 3s - loss: 0.1979 - binary_accuracy: 0.9096
Corrupt JPEG data: 65 extraneous bytes before marker 0xd9
268/291 [==========================>...] - ETA: 1s - loss: 0.1663 - binary_accuracy: 0.9269
Corrupt JPEG data: 239 extraneous bytes before marker 0xd9
282/291 [============================>.] - ETA: 0s - loss: 0.1628 - binary_accuracy: 0.9284
Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9
Corrupt JPEG data: 228 extraneous bytes before marker 0xd9
291/291 [==============================] - ETA: 0s - loss: 0.1620 - binary_accuracy: 0.9286
Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9
291/291 [==============================] - 29s 63ms/step - loss: 0.1620 - binary_accuracy: 0.9286 - val_loss: 0.0814 - val_binary_accuracy: 0.9686
Epoch 2/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1178 - binary_accuracy: 0.9511 - val_loss: 0.0785 - val_binary_accuracy: 0.9695
Epoch 3/20
291/291 [==============================] - 9s 30ms/step - loss: 0.1121 - binary_accuracy: 0.9536 - val_loss: 0.0748 - val_binary_accuracy: 0.9712
Epoch 4/20
291/291 [==============================] - 9s 29ms/step - loss: 0.1082 - binary_accuracy: 0.9554 - val_loss: 0.0754 - val_binary_accuracy: 0.9703
Epoch 5/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1034 - binary_accuracy: 0.9570 - val_loss: 0.0721 - val_binary_accuracy: 0.9725
Epoch 6/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0975 - binary_accuracy: 0.9602 - val_loss: 0.0748 - val_binary_accuracy: 0.9699
Epoch 7/20
291/291 [==============================] - 9s 29ms/step - loss: 0.0989 - binary_accuracy: 0.9595 - val_loss: 0.0732 - val_binary_accuracy: 0.9716
Epoch 8/20
291/291 [==============================] - 8s 29ms/step - loss: 0.1027 - binary_accuracy: 0.9566 - val_loss: 0.0787 - val_binary_accuracy: 0.9678
Epoch 9/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0959 - binary_accuracy: 0.9614 - val_loss: 0.0734 - val_binary_accuracy: 0.9729
Epoch 10/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0995 - binary_accuracy: 0.9588 - val_loss: 0.0717 - val_binary_accuracy: 0.9721
Epoch 11/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0957 - binary_accuracy: 0.9612 - val_loss: 0.0731 - val_binary_accuracy: 0.9725
Epoch 12/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0936 - binary_accuracy: 0.9622 - val_loss: 0.0751 - val_binary_accuracy: 0.9716
Epoch 13/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0965 - binary_accuracy: 0.9610 - val_loss: 0.0821 - val_binary_accuracy: 0.9695
Epoch 14/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0939 - binary_accuracy: 0.9618 - val_loss: 0.0742 - val_binary_accuracy: 0.9712
Epoch 15/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0974 - binary_accuracy: 0.9585 - val_loss: 0.0771 - val_binary_accuracy: 0.9712
Epoch 16/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9621 - val_loss: 0.0823 - val_binary_accuracy: 0.9699
Epoch 17/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9625 - val_loss: 0.0718 - val_binary_accuracy: 0.9708
Epoch 18/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0928 - binary_accuracy: 0.9616 - val_loss: 0.0738 - val_binary_accuracy: 0.9716
Epoch 19/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0922 - binary_accuracy: 0.9644 - val_loss: 0.0743 - val_binary_accuracy: 0.9716
Epoch 20/20
291/291 [==============================] - 8s 29ms/step - loss: 0.0885 - binary_accuracy: 0.9635 - val_loss: 0.0745 - val_binary_accuracy: 0.9695
<keras.callbacks.History at 0x7f849a3b2950>

Haz una ronda de puesta a punto de todo el modelo.

Finalmente, descongelemos el modelo base y entrenemos todo el modelo de un extremo a otro con una tasa de aprendizaje baja.

Es importante destacar que, aunque el modelo base se convierte en entrenable, que todavía se está ejecutando en el modo de inferencia desde que pasamos training=False al llamar cuando se construyó el modelo. Esto significa que las capas de normalización de lotes que se encuentran en el interior no actualizarán sus estadísticas de lotes. Si lo hicieran, causarían estragos en las representaciones aprendidas por el modelo hasta ahora.

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 150, 150, 3)]     0         
_________________________________________________________________
sequential_3 (Sequential)    (None, 150, 150, 3)       0         
_________________________________________________________________
rescaling (Rescaling)        (None, 150, 150, 3)       0         
_________________________________________________________________
xception (Functional)        (None, 5, 5, 2048)        20861480  
_________________________________________________________________
global_average_pooling2d (Gl (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense_7 (Dense)              (None, 1)                 2049      
=================================================================
Total params: 20,863,529
Trainable params: 20,809,001
Non-trainable params: 54,528
_________________________________________________________________
Epoch 1/10
291/291 [==============================] - 43s 131ms/step - loss: 0.0802 - binary_accuracy: 0.9692 - val_loss: 0.0580 - val_binary_accuracy: 0.9764
Epoch 2/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0542 - binary_accuracy: 0.9792 - val_loss: 0.0529 - val_binary_accuracy: 0.9764
Epoch 3/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0400 - binary_accuracy: 0.9832 - val_loss: 0.0510 - val_binary_accuracy: 0.9798
Epoch 4/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0313 - binary_accuracy: 0.9879 - val_loss: 0.0505 - val_binary_accuracy: 0.9819
Epoch 5/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0272 - binary_accuracy: 0.9904 - val_loss: 0.0485 - val_binary_accuracy: 0.9807
Epoch 6/10
291/291 [==============================] - 37s 128ms/step - loss: 0.0284 - binary_accuracy: 0.9901 - val_loss: 0.0497 - val_binary_accuracy: 0.9824
Epoch 7/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0198 - binary_accuracy: 0.9937 - val_loss: 0.0530 - val_binary_accuracy: 0.9802
Epoch 8/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0173 - binary_accuracy: 0.9930 - val_loss: 0.0572 - val_binary_accuracy: 0.9819
Epoch 9/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0113 - binary_accuracy: 0.9958 - val_loss: 0.0555 - val_binary_accuracy: 0.9837
Epoch 10/10
291/291 [==============================] - 37s 127ms/step - loss: 0.0091 - binary_accuracy: 0.9966 - val_loss: 0.0596 - val_binary_accuracy: 0.9832
<keras.callbacks.History at 0x7f83982d4cd0>

Después de 10 épocas, el ajuste fino nos da una buena mejora aquí.