ดูบน TensorFlow.org | ทำงานใน Google Colab | ดูบน GitHub | ดาวน์โหลดโน๊ตบุ๊ค | ดูรุ่น TF Hub |
โน๊ตบุ๊คนี้แสดงให้เห็นถึงวิธีการใช้ CropNet มันสำปะหลังโรคลักษณนาม โมเดลจาก TensorFlow Hub รูปแบบการจัดประเภทภาพของใบมันสำปะหลังเป็นหนึ่งใน 6 ชั้นเรียน: ทำลายแบคทีเรียโรคริ้วสีน้ำตาล, สีเขียวไรโรคโมเสคที่มีสุขภาพดีหรือไม่รู้จัก
colab นี้สาธิตวิธี:
- โหลด https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 โมเดลจาก TensorFlow Hub
- โหลด มันสำปะหลัง ชุดข้อมูลจาก TensorFlow ชุดข้อมูล (TFDS)
- จำแนกภาพใบมันสำปะหลังออกเป็น 4 ประเภทโรคมันสำปะหลังที่แตกต่างกัน หรือมีสุขภาพดีหรือไม่เป็นที่รู้จัก
- ประเมินความถูกต้องของการจําแนกและดูที่วิธีการที่มีประสิทธิภาพรุ่นคือเมื่อนำไปใช้ออกจากภาพโดเมน
นำเข้าและตั้งค่า
pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
ฟังก์ชั่น Helper สำหรับแสดงตัวอย่าง
def plot(examples, predictions=None):
# Get the images, labels, and optionally predictions
images = examples['image']
labels = examples['label']
batch_size = len(images)
if predictions is None:
predictions = batch_size * [None]
# Configure the layout of the grid
x = np.ceil(np.sqrt(batch_size))
y = np.ceil(batch_size / x)
fig = plt.figure(figsize=(x * 6, y * 7))
for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
# Render the image
ax = fig.add_subplot(x, y, i+1)
ax.imshow(image, aspect='auto')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
# Display the label and optionally prediction
x_label = 'Label: ' + name_map[class_names[label]]
if prediction is not None:
x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
ax.xaxis.label.set_color('green' if label == prediction else 'red')
ax.set_xlabel(x_label)
plt.show()
ชุดข้อมูล
โหลด Let 's ชุดข้อมูลจากมันสำปะหลัง TFDS
dataset, info = tfds.load('cassava', with_info=True)
ลองมาดูข้อมูลชุดข้อมูลเพื่อเรียนรู้เพิ่มเติมกัน เช่น คำอธิบายและการอ้างอิงและข้อมูลเกี่ยวกับจำนวนตัวอย่างที่มี
info
tfds.core.DatasetInfo( name='cassava', full_name='cassava/0.1.0', description=""" Cassava consists of leaf images for the cassava plant depicting healthy and four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD). Dataset consists of a total of 9430 labelled images. The 9430 labelled images are split into a training set (5656), a test set(1885) and a validation set (1889). The number of images per class are unbalanced with the two disease classes CMD and CBSD having 72% of the images. """, homepage='https://www.kaggle.com/c/cassava-disease/overview', data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0', download_size=1.26 GiB, dataset_size=Unknown size, features=FeaturesDict({ 'image': Image(shape=(None, None, 3), dtype=tf.uint8), 'image/filename': Text(shape=(), dtype=tf.string), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5), }), supervised_keys=('image', 'label'), disable_shuffling=False, splits={ 'test': <SplitInfo num_examples=1885, num_shards=4>, 'train': <SplitInfo num_examples=5656, num_shards=8>, 'validation': <SplitInfo num_examples=1889, num_shards=4>, }, citation="""@misc{mwebaze2019icassava, title={iCassava 2019Fine-Grained Visual Categorization Challenge}, author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira}, year={2019}, eprint={1908.02900}, archivePrefix={arXiv}, primaryClass={cs.CV} }""", )
ชุดข้อมูลมันสำปะหลังมีภาพของใบมันสำปะหลังที่มี 4 โรคที่แตกต่างกันเช่นเดียวกับใบมันสำปะหลังที่มีสุขภาพดี โมเดลสามารถทำนายคลาสทั้งหมดเหล่านี้ได้เช่นเดียวกับคลาสที่หกสำหรับ "ไม่ทราบ" เมื่อโมเดลไม่มั่นใจในการคาดการณ์
# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']
# Map the class names to human readable names
name_map = dict(
cmd='Mosaic Disease',
cbb='Bacterial Blight',
cgm='Green Mite',
cbsd='Brown Streak Disease',
healthy='Healthy',
unknown='Unknown')
print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes: ['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown'] ['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']
ก่อนที่เราจะสามารถป้อนข้อมูลไปยังแบบจำลองได้ เราต้องทำการประมวลผลล่วงหน้าเล็กน้อย โมเดลคาดหวังภาพ 224 x 224 พร้อมค่าช่องสัญญาณ RGB ใน [0, 1] มาทำให้เป็นมาตรฐานและปรับขนาดรูปภาพกันเถอะ
def preprocess_fn(data):
image = data['image']
# Normalize [0, 255] to [0, 1]
image = tf.cast(image, tf.float32)
image = image / 255.
# Resize the images to 224 x 224
image = tf.image.resize(image, (224, 224))
data['image'] = image
return data
ลองดูตัวอย่างบางส่วนจากชุดข้อมูล
batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
แบบอย่าง
มาโหลดตัวแยกประเภทจาก TF Hub และรับการทำนายและดูการทำนายของแบบจำลองอยู่ในตัวอย่างบางส่วน
classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
การประเมินและความทนทาน
ลองวัดความถูกต้องของลักษณนามของเราเกี่ยวกับการแยกของชุดข้อมูลที่ นอกจากนี้เรายังสามารถดูความแข็งแรงของรูปแบบโดยการประเมินผลการปฏิบัติงานในชุดข้อมูลที่ไม่ใช่มันสำปะหลัง สำหรับภาพของชุดข้อมูลพืชอื่น ๆ เช่น iNaturalist หรือถั่ว, รุ่นเกือบตลอดเวลาควรกลับไม่รู้จัก
พารามิเตอร์
DATASET = 'cassava'
DATASET_SPLIT = 'test'
BATCH_SIZE = 32
MAX_EXAMPLES = 1000
def label_to_unknown_fn(data):
data['label'] = 5 # Override label to unknown.
return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
ds = ds.map(label_to_unknown_fn)
dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)
# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
probabilities = classifier(examples['image'])
predictions = tf.math.argmax(probabilities, axis=-1)
labels = examples['label']
metric.update_state(labels, predictions)
print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88
เรียนรู้เพิ่มเติม
- เรียนรู้เพิ่มเติมเกี่ยวกับรูปแบบใน TensorFlow Hub: https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
- เรียนรู้วิธีการสร้างภาพที่กำหนดเองลักษณนามทำงานบนโทรศัพท์มือถือที่มี ML Kit กับ รุ่น Lite TensorFlow ของรุ่นนี้