CropNet : Détection des maladies du manioc

Voir sur TensorFlow.org Exécuter dans Google Colab Voir sur GitHub Télécharger le cahier Voir le modèle TF Hub

Ce bloc - notes montre comment utiliser le CropNet classificateur de la maladie du manioc modèle de tensorflow Hub. Le modèle classifie les images de feuilles de manioc dans l' une des 6 classes: la bactériose, la maladie de la striure brune, acariens vert, la mosaïque, la santé, ou inconnus.

Cette collaboration montre comment :

  • Chargez le https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 modèle de tensorflow Hub
  • Charger le manioc ensemble de données à partir de tensorflow Datasets (TFDS)
  • Classez les images de feuilles de manioc en 4 catégories distinctes de maladies du manioc ou comme saines ou inconnues.
  • Évaluer la précision du classificateur et examiner la robustesse du modèle est appliqué à partir d'images de domaine.

Importations et configuration

pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

Fonction d'assistance pour l'affichage d'exemples

Base de données

La charge de laisser le jeu de données de manioc de TFDS

dataset, info = tfds.load('cassava', with_info=True)

Jetons un coup d'œil aux informations sur l'ensemble de données pour en savoir plus, comme la description et la citation et des informations sur le nombre d'exemples disponibles

info
tfds.core.DatasetInfo(
    name='cassava',
    full_name='cassava/0.1.0',
    description="""
    Cassava consists of leaf images for the cassava plant depicting healthy and
    four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
    Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
    Dataset consists of a total of 9430 labelled images.
    The 9430 labelled images are split into a training set (5656), a test set(1885)
    and a validation set (1889). The number of images per class are unbalanced with
    the two disease classes CMD and CBSD having 72% of the images.
    """,
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0',
    download_size=1.26 GiB,
    dataset_size=Unknown size,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=1885, num_shards=4>,
        'train': <SplitInfo num_examples=5656, num_shards=8>,
        'validation': <SplitInfo num_examples=1889, num_shards=4>,
    },
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
)

L'ensemble de données de manioc a des images de feuilles de manioc avec 4 maladies distinctes, ainsi que des feuilles saines de manioc. Le modèle peut prédire toutes ces classes ainsi qu'une sixième classe pour « inconnu » lorsque le modèle n'est pas sûr de sa prédiction.

# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

Avant de pouvoir alimenter le modèle en données, nous devons effectuer un peu de prétraitement. Le modèle attend 224 x 224 images avec des valeurs de canal RVB dans [0, 1]. Normalisons et redimensionnons les images.

def preprocess_fn(data):
  image = data['image']

  # Normalize [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.

  # Resize the images to 224 x 224
  image = tf.image.resize(image, (224, 224))

  data['image'] = image
  return data

Jetons un coup d'œil à quelques exemples de l'ensemble de données

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)

png

Modèle

Chargeons le classificateur de TF Hub et obtenons des prédictions et voyons les prédictions du modèle sur quelques exemples

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)

png

Évaluation & robustesse

Mesurons la précision de notre classificateur sur une scission de l'ensemble de données. Nous pouvons aussi regarder la robustesse du modèle en évaluant ses performances sur un ensemble de données non-manioc. Pour l' image d'autres ensembles de données de plantes comme iNaturalist ou les haricots, le modèle doit retourner presque toujours inconnu.

Paramètres

def label_to_unknown_fn(data):
  data['label'] = 5  # Override label to unknown.
  return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

Apprendre encore plus