Segmentation des images

Voir sur TensorFlow.org Exécuter dans Google Colab Voir la source sur GitHub Télécharger le cahier

Ce tutoriel se concentre sur la tâche de segmentation d'image, en utilisant un U-Net modifié.

Qu'est-ce que la segmentation d'images ?

Dans une tâche de classification d'images, le réseau attribue une étiquette (ou classe) à chaque image d'entrée. Cependant, supposons que vous vouliez connaître la forme de cet objet, quel pixel appartient à quel objet, etc. Dans ce cas, vous voudrez attribuer une classe à chaque pixel de l'image. Cette tâche est connue sous le nom de segmentation. Un modèle de segmentation renvoie des informations beaucoup plus détaillées sur l'image. La segmentation d'images a de nombreuses applications dans l'imagerie médicale, les voitures autonomes et l'imagerie par satellite, pour n'en nommer que quelques-unes.

Ce tutoriel utilise l' ensemble de données Oxford-IIIT Pet ( Parkhi et al, 2012 ). L'ensemble de données se compose d'images de 37 races d'animaux de compagnie, avec 200 images par race (~ 100 chacune dans les divisions d'entraînement et de test). Chaque image comprend les étiquettes correspondantes et les masques pixel par pixel. Les masques sont des étiquettes de classe pour chaque pixel. Chaque pixel se voit attribuer l'une des trois catégories suivantes :

  • Classe 1 : Pixel appartenant à l'animal.
  • Classe 2 : Pixel bordant l'animal.
  • Classe 3 : aucune des réponses ci-dessus/un pixel environnant.
pip install git+https://github.com/tensorflow/examples.git
import tensorflow as tf

import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

from IPython.display import clear_output
import matplotlib.pyplot as plt

Télécharger le jeu de données Oxford-IIIT Pets

L'ensemble de données est disponible à partir des ensembles de données TensorFlow . Les masques de segmentation sont inclus dans la version 3+.

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)

De plus, les valeurs de couleur de l'image sont normalisées dans la plage [0,1] . Enfin, comme mentionné ci-dessus, les pixels du masque de segmentation sont étiquetés soit {1, 2, 3}. Par souci de commodité, soustrayez 1 du masque de segmentation, ce qui donne des étiquettes : {0, 1, 2}.

def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask
def load_image(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(datapoint['segmentation_mask'], (128, 128))

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

L'ensemble de données contient déjà les fractionnements d'entraînement et de test requis, continuez donc à utiliser les mêmes fractionnements.

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

La classe suivante effectue une augmentation simple en retournant une image de manière aléatoire. Accédez au didacticiel d' augmentation d'image pour en savoir plus.

class Augment(tf.keras.layers.Layer):
  def __init__(self, seed=42):
    super().__init__()
    # both use the same seed, so they'll make the same random changes.
    self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
    self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)

  def call(self, inputs, labels):
    inputs = self.augment_inputs(inputs)
    labels = self.augment_labels(labels)
    return inputs, labels

Construisez le pipeline d'entrée, en appliquant l'augmentation après avoir regroupé les entrées.

train_batches = (
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())
    .prefetch(buffer_size=tf.data.AUTOTUNE))

test_batches = test_images.batch(BATCH_SIZE)

Visualisez un exemple d'image et son masque correspondant à partir du jeu de données.

def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()
for images, masks in train_batches.take(2):
  sample_image, sample_mask = images[0], masks[0]
  display([sample_image, sample_mask])
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

png

png

2022-01-26 05:14:45.972101: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

Définir le modèle

Le modèle utilisé ici est un U-Net modifié. Un U-Net se compose d'un encodeur (downsampler) et d'un décodeur (upsampler). Afin d'apprendre des fonctionnalités robustes et de réduire le nombre de paramètres pouvant être formés, vous utiliserez un modèle pré-formé - MobileNetV2 - comme encodeur. Pour le décodeur, vous utiliserez le bloc upsample, qui est déjà implémenté dans l'exemple pix2pix du référentiel TensorFlow Examples. (Découvrez le pix2pix : Traduction d'image à image avec un didacticiel GAN ​​conditionnel dans un bloc-notes.)

Comme mentionné, l'encodeur sera un modèle MobileNetV2 pré-entraîné qui est préparé et prêt à être utilisé dans tf.keras.applications . L'encodeur se compose de sorties spécifiques des couches intermédiaires du modèle. Notez que l'encodeur ne sera pas formé pendant le processus de formation.

base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

# Use the activations of these layers
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

down_stack.trainable = False
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step
9420800/9406464 [==============================] - 0s 0us/step

Le décodeur/suréchantillonneur est simplement une série de blocs de suréchantillonnage implémentés dans des exemples TensorFlow.

up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]
def unet_model(output_channels:int):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # Downsampling through the model
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # This is the last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      filters=output_channels, kernel_size=3, strides=2,
      padding='same')  #64x64 -> 128x128

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

Notez que le nombre de filtres sur la dernière couche est défini sur le nombre de output_channels . Ce sera un canal de sortie par classe.

Former le modèle

Maintenant, il ne reste plus qu'à compiler et former le modèle.

Puisqu'il s'agit d'un problème de classification multiclasse, utilisez la fonction de perte tf.keras.losses.CategoricalCrossentropy avec l'argument from_logits défini sur True , car les étiquettes sont des entiers scalaires au lieu de vecteurs de scores pour chaque pixel de chaque classe.

Lors de l'exécution de l'inférence, l'étiquette attribuée au pixel est le canal avec la valeur la plus élevée. C'est ce que fait la fonction create_mask .

OUTPUT_CLASSES = 3

model = unet_model(output_channels=OUTPUT_CLASSES)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

Jetez un coup d'œil à l'architecture du modèle résultant :

tf.keras.utils.plot_model(model, show_shapes=True)

png

Essayez le modèle pour vérifier ce qu'il prédit avant l'entraînement.

def create_mask(pred_mask):
  pred_mask = tf.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]
def show_predictions(dataset=None, num=1):
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])
show_predictions()

png

Le rappel défini ci-dessous est utilisé pour observer comment le modèle s'améliore pendant sa formation.

class DisplayCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    clear_output(wait=True)
    show_predictions()
    print ('\nSample Prediction after epoch {}\n'.format(epoch+1))
EPOCHS = 20
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_batches, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_batches,
                          callbacks=[DisplayCallback()])

png

Sample Prediction after epoch 20

57/57 [==============================] - 4s 62ms/step - loss: 0.1838 - accuracy: 0.9187 - val_loss: 0.2797 - val_accuracy: 0.8955
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

plt.figure()
plt.plot(model_history.epoch, loss, 'r', label='Training loss')
plt.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

png

Faire des prédictions

Maintenant, faites des prédictions. Afin de gagner du temps, le nombre d'époques a été réduit, mais vous pouvez le définir plus haut pour obtenir des résultats plus précis.

show_predictions(test_batches, 3)

png

png

png

Facultatif : Classes et pondérations de classe déséquilibrées

Les ensembles de données de segmentation sémantique peuvent être très déséquilibrés, ce qui signifie que des pixels de classe particuliers peuvent être plus présents à l'intérieur des images que dans d'autres classes. Étant donné que les problèmes de segmentation peuvent être traités comme des problèmes de classification par pixel, vous pouvez traiter le problème de déséquilibre en pondérant la fonction de perte pour en tenir compte. C'est une façon simple et élégante de régler ce problème. Consultez le didacticiel Classification sur données déséquilibrées pour en savoir plus.

Pour éviter toute ambiguïté , Model.fit ne prend pas en charge l'argument class_weight pour les entrées avec plus de 3 dimensions.

try:
  model_history = model.fit(train_batches, epochs=EPOCHS,
                            steps_per_epoch=STEPS_PER_EPOCH,
                            class_weight = {0:2.0, 1:2.0, 2:1.0})
  assert False
except Exception as e:
  print(f"Expected {type(e).__name__}: {e}")
Expected ValueError: `class_weight` not supported for 3+ dimensional targets.

Donc, dans ce cas, vous devez mettre en œuvre la pondération vous-même. Pour ce faire, vous utiliserez des exemples de poids : en plus des paires (data, label) , Model.fit accepte également les triplets (data, label, sample_weight) .

Model.fit propage le sample_weight aux pertes et aux métriques, qui acceptent également un argument sample_weight . Le poids de l'échantillon est multiplié par la valeur de l'échantillon avant l'étape de réduction. Par example:

label = [0,0]
prediction = [[-3., 0], [-3, 0]] 
sample_weight = [1, 10] 

loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True,
                                               reduction=tf.losses.Reduction.NONE)
loss(label, prediction, sample_weight).numpy()
array([ 3.0485873, 30.485874 ], dtype=float32)

Donc, pour créer des exemples de poids pour ce didacticiel, vous avez besoin d'une fonction qui prend une paire (data, label) et renvoie un (data, label, sample_weight) . Où sample_weight est une image à 1 canal contenant le poids de classe pour chaque pixel.

L'implémentation la plus simple possible consiste à utiliser l'étiquette comme index dans une liste class_weight :

def add_sample_weights(image, label):
  # The weights for each class, with the constraint that:
  #     sum(class_weights) == 1.0
  class_weights = tf.constant([2.0, 2.0, 1.0])
  class_weights = class_weights/tf.reduce_sum(class_weights)

  # Create an image of `sample_weights` by using the label at each pixel as an 
  # index into the `class weights` .
  sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))

  return image, label, sample_weights

Les éléments de l'ensemble de données résultant contiennent 3 images chacun :

train_batches.map(add_sample_weights).element_spec
(TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None))

Vous pouvez maintenant entraîner un modèle sur cet ensemble de données pondéré :

weighted_model = unet_model(OUTPUT_CLASSES)
weighted_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
weighted_model.fit(
    train_batches.map(add_sample_weights),
    epochs=1,
    steps_per_epoch=10)
10/10 [==============================] - 3s 44ms/step - loss: 0.3099 - accuracy: 0.6063
<keras.callbacks.History at 0x7fa75d0f3e50>

Prochaines étapes

Maintenant que vous comprenez ce qu'est la segmentation d'image et comment elle fonctionne, vous pouvez essayer ce didacticiel avec différentes sorties de couche intermédiaire, ou même différents modèles pré-entraînés. Vous pouvez également vous mettre au défi en essayant le défi de masquage d'image Carvana hébergé sur Kaggle.

Vous pouvez également consulter l' API de détection d'objets Tensorflow pour un autre modèle que vous pouvez recycler sur vos propres données. Des modèles pré-entraînés sont disponibles sur TensorFlow Hub