הצג באתר TensorFlow.org | הפעל בגוגל קולאב | הצג ב-GitHub | הורד מחברת | ראה דגם TF Hub |
מופעים מחברת זו כיצד להשתמש CropNet מסווג מחל קסאווה המודל מרכזת TensorFlow. התמונות המסווגות מודל של עלים קסאווה לאחת 6 כיתות: שידפון החיידקים, מחל פס חום, קרדית ירוקה, מחלת פסיפס, בריא, או לא ידוע.
קולב זה מדגים כיצד:
- טען את https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 מודל מ TensorFlow Hub
- טען את קסאווה במערך מן TensorFlow מערכי נתונים (TFDS)
- סווגו תמונות של עלי קסאווה ל-4 קטגוריות שונות של מחלת קסאווה או כבריאים או לא ידועים.
- להעריך את הדיוק של המסווג ולהסתכל איך הוא המודל חזק כאשר מוחל מתוך תמונות תחום.
יבוא והגדרה
pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
פונקציית עוזר להצגת דוגמאות
def plot(examples, predictions=None):
# Get the images, labels, and optionally predictions
images = examples['image']
labels = examples['label']
batch_size = len(images)
if predictions is None:
predictions = batch_size * [None]
# Configure the layout of the grid
x = np.ceil(np.sqrt(batch_size))
y = np.ceil(batch_size / x)
fig = plt.figure(figsize=(x * 6, y * 7))
for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
# Render the image
ax = fig.add_subplot(x, y, i+1)
ax.imshow(image, aspect='auto')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
# Display the label and optionally prediction
x_label = 'Label: ' + name_map[class_names[label]]
if prediction is not None:
x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
ax.xaxis.label.set_color('green' if label == prediction else 'red')
ax.set_xlabel(x_label)
plt.show()
מערך נתונים
עומס בסיס הנתונים קסאווה בואו מן TFDS
dataset, info = tfds.load('cassava', with_info=True)
בואו נסתכל על המידע של מערך הנתונים כדי ללמוד עליו יותר, כמו התיאור והציטוט ומידע על כמה דוגמאות זמינות
info
tfds.core.DatasetInfo( name='cassava', full_name='cassava/0.1.0', description=""" Cassava consists of leaf images for the cassava plant depicting healthy and four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD). Dataset consists of a total of 9430 labelled images. The 9430 labelled images are split into a training set (5656), a test set(1885) and a validation set (1889). The number of images per class are unbalanced with the two disease classes CMD and CBSD having 72% of the images. """, homepage='https://www.kaggle.com/c/cassava-disease/overview', data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0', download_size=1.26 GiB, dataset_size=Unknown size, features=FeaturesDict({ 'image': Image(shape=(None, None, 3), dtype=tf.uint8), 'image/filename': Text(shape=(), dtype=tf.string), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5), }), supervised_keys=('image', 'label'), disable_shuffling=False, splits={ 'test': <SplitInfo num_examples=1885, num_shards=4>, 'train': <SplitInfo num_examples=5656, num_shards=8>, 'validation': <SplitInfo num_examples=1889, num_shards=4>, }, citation="""@misc{mwebaze2019icassava, title={iCassava 2019Fine-Grained Visual Categorization Challenge}, author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira}, year={2019}, eprint={1908.02900}, archivePrefix={arXiv}, primaryClass={cs.CV} }""", )
מערך נתון של קסאווה יש תמונות של עלי קסבה עם 4 מחלות ברורות כמו גם עלים קסאווה בריאים. המודל יכול לחזות את כל המחלקות הללו כמו גם מחלקה שישית עבור "לא ידוע" כאשר המודל אינו בטוח בתחזית שלו.
# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']
# Map the class names to human readable names
name_map = dict(
cmd='Mosaic Disease',
cbb='Bacterial Blight',
cgm='Green Mite',
cbsd='Brown Streak Disease',
healthy='Healthy',
unknown='Unknown')
print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes: ['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown'] ['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']
לפני שנוכל להזין את הנתונים למודל, עלינו לבצע מעט עיבוד מקדים. המודל מצפה לתמונות של 224 x 224 עם ערכי ערוץ RGB ב-[0, 1]. בואו ננרמל ונשנה את גודל התמונות.
def preprocess_fn(data):
image = data['image']
# Normalize [0, 255] to [0, 1]
image = tf.cast(image, tf.float32)
image = image / 255.
# Resize the images to 224 x 224
image = tf.image.resize(image, (224, 224))
data['image'] = image
return data
בואו נסתכל על כמה דוגמאות מתוך מערך הנתונים
batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
דֶגֶם
בואו נטען את המסווגן מ-TF Hub ונקבל כמה תחזיות ונראה את התחזיות של המודל על כמה דוגמאות
classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
הערכה וחוסן
בואו למדוד את מידת הדיוק של מסווג שלנו על פיצול של בסיס הנתונים. אנחנו גם יכולים להסתכל על חוסנו של המודל על ידי והערכת הביצועים שלו על בסיס הנתונים הלא-קסבה. לצורך הדימוי של מערכי נתונים צמחיים אחרים כמו iNaturalist או שעועית, המודל צריך כמעט תמיד לחזור ידוע.
פרמטרים
DATASET = 'cassava'
DATASET_SPLIT = 'test'
BATCH_SIZE = 32
MAX_EXAMPLES = 1000
def label_to_unknown_fn(data):
data['label'] = 5 # Override label to unknown.
return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
ds = ds.map(label_to_unknown_fn)
dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)
# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
probabilities = classifier(examples['image'])
predictions = tf.math.argmax(probabilities, axis=-1)
labels = examples['label']
metric.update_state(labels, predictions)
print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88
למד עוד
- למידע נוסף על המודל על TensorFlow Hub: https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
- למד כיצד לבנות ריצה מסווגת תמונה מותאמת אישית בטלפון נייד עם ערכת ML עם גרסת לייט TensorFlow של הדגם הזה .