CropNet: Deteksi Penyakit Singkong

Lihat di TensorFlow.org Jalankan di Google Colab Lihat di GitHub Unduh buku catatan Lihat model TF Hub

Notebook menunjukkan ini bagaimana menggunakan CropNet singkong penyakit classifier Model dari TensorFlow Hub. Model mengklasifikasikan gambar daun singkong menjadi salah satu dari 6 kelas: hawar bakteri, penyakit beruntun coklat, tungau hijau, penyakit mosaik, sehat, atau tidak diketahui.

Colab ini menunjukkan cara:

  • Memuat https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 Model dari TensorFlow Hub
  • Memuat singkong dataset dari TensorFlow dataset (TFDS)
  • Klasifikasikan citra daun singkong menjadi 4 kategori penyakit singkong yang berbeda atau sehat atau tidak diketahui.
  • Mengevaluasi akurasi dari classifier dan melihat bagaimana kuat model ini bila diterapkan dari gambar domain.

Impor dan pengaturan

pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub

Fungsi pembantu untuk menampilkan contoh

Himpunan data

Mari beban singkong dataset dari TFDS

dataset, info = tfds.load('cassava', with_info=True)

Mari kita lihat info dataset untuk mempelajarinya lebih lanjut, seperti deskripsi dan kutipan serta informasi tentang berapa banyak contoh yang tersedia

info
tfds.core.DatasetInfo(
    name='cassava',
    full_name='cassava/0.1.0',
    description="""
    Cassava consists of leaf images for the cassava plant depicting healthy and
    four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial
    Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD).
    Dataset consists of a total of 9430 labelled images.
    The 9430 labelled images are split into a training set (5656), a test set(1885)
    and a validation set (1889). The number of images per class are unbalanced with
    the two disease classes CMD and CBSD having 72% of the images.
    """,
    homepage='https://www.kaggle.com/c/cassava-disease/overview',
    data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0',
    download_size=1.26 GiB,
    dataset_size=Unknown size,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=tf.uint8),
        'image/filename': Text(shape=(), dtype=tf.string),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'test': <SplitInfo num_examples=1885, num_shards=4>,
        'train': <SplitInfo num_examples=5656, num_shards=8>,
        'validation': <SplitInfo num_examples=1889, num_shards=4>,
    },
    citation="""@misc{mwebaze2019icassava,
        title={iCassava 2019Fine-Grained Visual Categorization Challenge},
        author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira},
        year={2019},
        eprint={1908.02900},
        archivePrefix={arXiv},
        primaryClass={cs.CV}
    }""",
)

Singkong dataset memiliki gambar daun singkong dengan 4 penyakit yang berbeda serta daun singkong yang sehat. Model dapat memprediksi semua kelas ini serta kelas keenam untuk "tidak diketahui" ketika model tidak yakin dengan prediksinya.

# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']

# Map the class names to human readable names
name_map = dict(
    cmd='Mosaic Disease',
    cbb='Bacterial Blight',
    cgm='Green Mite',
    cbsd='Brown Streak Disease',
    healthy='Healthy',
    unknown='Unknown')

print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes:
['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown']
['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']

Sebelum kita dapat memasukkan data ke model, kita perlu melakukan sedikit preprocessing. Model mengharapkan gambar 224 x 224 dengan nilai saluran RGB di [0, 1]. Mari normalkan dan ubah ukuran gambar.

def preprocess_fn(data):
  image = data['image']

  # Normalize [0, 255] to [0, 1]
  image = tf.cast(image, tf.float32)
  image = image / 255.

  # Resize the images to 224 x 224
  image = tf.image.resize(image, (224, 224))

  data['image'] = image
  return data

Mari kita lihat beberapa contoh dari kumpulan data

batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)

png

Model

Mari kita muat classifier dari TF Hub dan dapatkan beberapa prediksi dan lihat prediksi model pada beberapa contoh

classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)

png

Evaluasi & ketahanan

Mari kita mengukur akurasi classifier kami pada perpecahan dari dataset. Kami juga dapat melihat kekokohan model dengan mengevaluasi kinerjanya pada dataset non-singkong. Untuk gambar dataset tanaman lain seperti iNaturalist atau kacang-kacangan, model harus hampir selalu kembali tidak diketahui.

Parameter

def label_to_unknown_fn(data):
  data['label'] = 5  # Override label to unknown.
  return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
  ds = ds.map(label_to_unknown_fn)
  dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)

# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
  probabilities = classifier(examples['image'])
  predictions = tf.math.argmax(probabilities, axis=-1)
  labels = examples['label']
  metric.update_state(labels, predictions)

print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88

Belajarlah lagi