CropNet: Detección de la enfermedad de la mandioca

Ver en TensorFlow.org Ejecutar en Google Colab Ver en GitHub Descargar cuaderno Ver modelo TF Hub

Este portátil se muestra cómo utilizar el CropNet clasificador enfermedad yuca modelo de TensorFlow concentradores. Las imágenes clasifica modelo de las hojas de yuca en una de las 6 clases: tizón bacteriano, enfermedad de la raya marrón, verde ácaros, la enfermedad del mosaico, saludable, o desconocidos.

Este colab demuestra cómo:

  • Cargar el https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 modelo de TensorFlow Hub
  • Cargar la yuca conjunto de datos de TensorFlow conjuntos de datos (TFDS)
  • Clasifique las imágenes de hojas de yuca en 4 categorías distintas de enfermedades de la yuca o como sanas o desconocidas.
  • Evaluar la precisión del clasificador y la mirada a la solidez del modelo es cuando se aplica a partir de imágenes de dominio.

Importaciones y montaje

pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

Función auxiliar para mostrar ejemplos.

Conjunto de datos

Vamos a la carga el conjunto de datos de yuca TFDS

dataset, info = tfds.load('cassava', with_info=True)

Echemos un vistazo a la información del conjunto de datos para obtener más información al respecto, como la descripción, la cita y la información sobre cuántos ejemplos están disponibles.

info
tfds.core.DatasetInfo(
    name='cassava',
    full_name='cassava/0.1.0',
    description="""
    Cassava consists of leaf images for the cassava plant depicting healthy and
    four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
    Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
    Dataset consists of a total of 9430 labelled images.
    The 9430 labelled images are split into a training set (5656), a test set(1885)
    and a validation set (1889). The number of images per class are unbalanced with
    the two disease classes CMD and CBSD having 72% of the images.
    """,
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0',
    download_size=1.26 GiB,
    dataset_size=Unknown size,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=1885, num_shards=4>,
        'train': <SplitInfo num_examples=5656, num_shards=8>,
        'validation': <SplitInfo num_examples=1889, num_shards=4>,
    },
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
)

El conjunto de datos yuca tiene imágenes de hojas de yuca con 4 diferentes enfermedades, así como las hojas de yuca sanas. El modelo puede predecir todas estas clases, así como la sexta clase para "desconocido" cuando el modelo no confía en su predicción.

# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

Antes de que podamos enviar los datos al modelo, necesitamos hacer un poco de preprocesamiento. El modelo espera imágenes de 224 x 224 con valores de canal RGB en [0, 1]. Normalicemos y redimensionemos las imágenes.

def preprocess_fn(data):
  image = data['image']

  # Normalize [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.

  # Resize the images to 224 x 224
  image = tf.image.resize(image, (224, 224))

  data['image'] = image
  return data

Echemos un vistazo a algunos ejemplos del conjunto de datos.

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)

png

Modelo

Carguemos el clasificador de TF Hub y obtengamos algunas predicciones y veamos las predicciones del modelo en algunos ejemplos.

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)

png

Evaluación y robustez

Vamos a medir la precisión de nuestra clasificador en una fracción del conjunto de datos. También podemos ver la robustez del modelo mediante la evaluación de su rendimiento en un conjunto de datos no yuca. Para la imagen de otros conjuntos de datos de plantas como iNaturalist o frijoles, el modelo debe devolver casi siempre desconocida.

Parámetros

def label_to_unknown_fn(data):
  data['label'] = 5  # Override label to unknown.
  return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

Aprende más