Visualizza su TensorFlow.org | Esegui in Google Colab | Visualizza su GitHub | Scarica taccuino | Vedi il modello del mozzo TF |
Questa mostra notebook come usare il CropNet malattia classificatore manioca modello da tensorflow Hub. Le modello classifica immagini di foglie di manioca in una delle 6 classi: ruggine batterica, malattia striscia marrone, acari verde, la malattia del mosaico, in buona salute, o sconosciute.
Questa collaborazione mostra come:
- Caricare il https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 modello da tensorflow Hub
- Caricare la manioca set di dati da tensorflow Datasets (TFDS)
- Classifica le immagini delle foglie di manioca in 4 distinte categorie di malattie della manioca o come sane o sconosciute.
- Valutare l'accuratezza del classificatore e sguardo a come robusto il modello è quando applicata a di immagini di dominio.
Importazioni e configurazione
pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
Funzione di supporto per la visualizzazione di esempi
def plot(examples, predictions=None):
# Get the images, labels, and optionally predictions
images = examples['image']
labels = examples['label']
batch_size = len(images)
if predictions is None:
predictions = batch_size * [None]
# Configure the layout of the grid
x = np.ceil(np.sqrt(batch_size))
y = np.ceil(batch_size / x)
fig = plt.figure(figsize=(x * 6, y * 7))
for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
# Render the image
ax = fig.add_subplot(x, y, i+1)
ax.imshow(image, aspect='auto')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
# Display the label and optionally prediction
x_label = 'Label: ' + name_map[class_names[label]]
if prediction is not None:
x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
ax.xaxis.label.set_color('green' if label == prediction else 'red')
ax.set_xlabel(x_label)
plt.show()
set di dati
Carico di lasciare che il set di dati da manioca TFDS
dataset, info = tfds.load('cassava', with_info=True)
Diamo un'occhiata alle informazioni sul set di dati per saperne di più, come la descrizione e la citazione e le informazioni su quanti esempi sono disponibili
info
tfds.core.DatasetInfo( name='cassava', full_name='cassava/0.1.0', description=""" Cassava consists of leaf images for the cassava plant depicting healthy and four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD). Dataset consists of a total of 9430 labelled images. The 9430 labelled images are split into a training set (5656), a test set(1885) and a validation set (1889). The number of images per class are unbalanced with the two disease classes CMD and CBSD having 72% of the images. """, homepage='https://www.kaggle.com/c/cassava-disease/overview', data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0', download_size=1.26 GiB, dataset_size=Unknown size, features=FeaturesDict({ 'image': Image(shape=(None, None, 3), dtype=tf.uint8), 'image/filename': Text(shape=(), dtype=tf.string), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5), }), supervised_keys=('image', 'label'), disable_shuffling=False, splits={ 'test': <SplitInfo num_examples=1885, num_shards=4>, 'train': <SplitInfo num_examples=5656, num_shards=8>, 'validation': <SplitInfo num_examples=1889, num_shards=4>, }, citation="""@misc{mwebaze2019icassava, title={iCassava 2019Fine-Grained Visual Categorization Challenge}, author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira}, year={2019}, eprint={1908.02900}, archivePrefix={arXiv}, primaryClass={cs.CV} }""", )
Il set di dati manioca ha immagini di foglie di manioca con 4 malattie distinte e foglie di manioca sani. Il modello può prevedere tutte queste classi e la sesta classe per "sconosciuto" quando il modello non è sicuro della sua previsione.
# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']
# Map the class names to human readable names
name_map = dict(
cmd='Mosaic Disease',
cbb='Bacterial Blight',
cgm='Green Mite',
cbsd='Brown Streak Disease',
healthy='Healthy',
unknown='Unknown')
print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes: ['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown'] ['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']
Prima di poter fornire i dati al modello, è necessario eseguire un po' di pre-elaborazione. Il modello prevede immagini 224 x 224 con valori del canale RGB in [0, 1]. Normalizziamo e ridimensioniamo le immagini.
def preprocess_fn(data):
image = data['image']
# Normalize [0, 255] to [0, 1]
image = tf.cast(image, tf.float32)
image = image / 255.
# Resize the images to 224 x 224
image = tf.image.resize(image, (224, 224))
data['image'] = image
return data
Diamo un'occhiata ad alcuni esempi dal set di dati
batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
Modello
Carichiamo il classificatore da TF Hub e otteniamo alcune previsioni e vediamo le previsioni del modello su alcuni esempi
classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
Valutazione e robustezza
Diamo misurare l'accuratezza del nostro classificatore su una scissione del set di dati. Possiamo anche guardare la robustezza del modello valutando le sue prestazioni su un set di dati non manioca. Per l'immagine di altre serie di dati vegetali come iNaturalist o fagioli, il modello dovrebbe quasi sempre tornare sconosciuta.
Parametri
DATASET = 'cassava'
DATASET_SPLIT = 'test'
BATCH_SIZE = 32
MAX_EXAMPLES = 1000
def label_to_unknown_fn(data):
data['label'] = 5 # Override label to unknown.
return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
ds = ds.map(label_to_unknown_fn)
dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)
# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
probabilities = classifier(examples['image'])
predictions = tf.math.argmax(probabilities, axis=-1)
labels = examples['label']
metric.update_state(labels, predictions)
print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88
Scopri di più
- Ulteriori informazioni sul modello sul tensorflow Hub: https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
- Imparare a costruire un'immagine personalizzata classificatore in esecuzione su un telefono cellulare con ML Kit con la versione Lite tensorflow di questo modello .