Voir sur TensorFlow.org | Exécuter dans Google Colab | Voir sur GitHub | Télécharger le cahier | Voir le modèle TF Hub |
Ce bloc - notes montre comment utiliser le CropNet classificateur de la maladie du manioc modèle de tensorflow Hub. Le modèle classifie les images de feuilles de manioc dans l' une des 6 classes: la bactériose, la maladie de la striure brune, acariens vert, la mosaïque, la santé, ou inconnus.
Cette collaboration montre comment :
- Chargez le https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 modèle de tensorflow Hub
- Charger le manioc ensemble de données à partir de tensorflow Datasets (TFDS)
- Classez les images de feuilles de manioc en 4 catégories distinctes de maladies du manioc ou comme saines ou inconnues.
- Évaluer la précision du classificateur et examiner la robustesse du modèle est appliqué à partir d'images de domaine.
Importations et configuration
pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
Fonction d'assistance pour l'affichage d'exemples
def plot(examples, predictions=None):
# Get the images, labels, and optionally predictions
images = examples['image']
labels = examples['label']
batch_size = len(images)
if predictions is None:
predictions = batch_size * [None]
# Configure the layout of the grid
x = np.ceil(np.sqrt(batch_size))
y = np.ceil(batch_size / x)
fig = plt.figure(figsize=(x * 6, y * 7))
for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
# Render the image
ax = fig.add_subplot(x, y, i+1)
ax.imshow(image, aspect='auto')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
# Display the label and optionally prediction
x_label = 'Label: ' + name_map[class_names[label]]
if prediction is not None:
x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
ax.xaxis.label.set_color('green' if label == prediction else 'red')
ax.set_xlabel(x_label)
plt.show()
Base de données
La charge de laisser le jeu de données de manioc de TFDS
dataset, info = tfds.load('cassava', with_info=True)
Jetons un coup d'œil aux informations sur l'ensemble de données pour en savoir plus, comme la description et la citation et des informations sur le nombre d'exemples disponibles
info
tfds.core.DatasetInfo( name='cassava', full_name='cassava/0.1.0', description=""" Cassava consists of leaf images for the cassava plant depicting healthy and four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD). Dataset consists of a total of 9430 labelled images. The 9430 labelled images are split into a training set (5656), a test set(1885) and a validation set (1889). The number of images per class are unbalanced with the two disease classes CMD and CBSD having 72% of the images. """, homepage='https://www.kaggle.com/c/cassava-disease/overview', data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0', download_size=1.26 GiB, dataset_size=Unknown size, features=FeaturesDict({ 'image': Image(shape=(None, None, 3), dtype=tf.uint8), 'image/filename': Text(shape=(), dtype=tf.string), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5), }), supervised_keys=('image', 'label'), disable_shuffling=False, splits={ 'test': <SplitInfo num_examples=1885, num_shards=4>, 'train': <SplitInfo num_examples=5656, num_shards=8>, 'validation': <SplitInfo num_examples=1889, num_shards=4>, }, citation="""@misc{mwebaze2019icassava, title={iCassava 2019Fine-Grained Visual Categorization Challenge}, author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira}, year={2019}, eprint={1908.02900}, archivePrefix={arXiv}, primaryClass={cs.CV} }""", )
L'ensemble de données de manioc a des images de feuilles de manioc avec 4 maladies distinctes, ainsi que des feuilles saines de manioc. Le modèle peut prédire toutes ces classes ainsi qu'une sixième classe pour « inconnu » lorsque le modèle n'est pas sûr de sa prédiction.
# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']
# Map the class names to human readable names
name_map = dict(
cmd='Mosaic Disease',
cbb='Bacterial Blight',
cgm='Green Mite',
cbsd='Brown Streak Disease',
healthy='Healthy',
unknown='Unknown')
print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes: ['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown'] ['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']
Avant de pouvoir alimenter le modèle en données, nous devons effectuer un peu de prétraitement. Le modèle attend 224 x 224 images avec des valeurs de canal RVB dans [0, 1]. Normalisons et redimensionnons les images.
def preprocess_fn(data):
image = data['image']
# Normalize [0, 255] to [0, 1]
image = tf.cast(image, tf.float32)
image = image / 255.
# Resize the images to 224 x 224
image = tf.image.resize(image, (224, 224))
data['image'] = image
return data
Jetons un coup d'œil à quelques exemples de l'ensemble de données
batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
Modèle
Chargeons le classificateur de TF Hub et obtenons des prédictions et voyons les prédictions du modèle sur quelques exemples
classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
Évaluation & robustesse
Mesurons la précision de notre classificateur sur une scission de l'ensemble de données. Nous pouvons aussi regarder la robustesse du modèle en évaluant ses performances sur un ensemble de données non-manioc. Pour l' image d'autres ensembles de données de plantes comme iNaturalist ou les haricots, le modèle doit retourner presque toujours inconnu.
Paramètres
DATASET = 'cassava'
DATASET_SPLIT = 'test'
BATCH_SIZE = 32
MAX_EXAMPLES = 1000
def label_to_unknown_fn(data):
data['label'] = 5 # Override label to unknown.
return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
ds = ds.map(label_to_unknown_fn)
dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)
# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
probabilities = classifier(examples['image'])
predictions = tf.math.argmax(probabilities, axis=-1)
labels = examples['label']
metric.update_state(labels, predictions)
print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88
Apprendre encore plus
- En savoir plus sur le modèle sur tensorflow Hub: https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
- Apprenez comment construire un classificateur image personnalisée fonctionnant sur un téléphone mobile avec ML Kit avec la Version tensorflow Lite de ce modèle .