TensorFlow.org'da görüntüleyin | Google Colab'da çalıştırın | GitHub'da görüntüle | Not defterini indir | TF Hub modeline bakın |
Bu defter gösterileri nasıl CropNet kullanmak manyok hastalığı sınıflandırıcı TensorFlow Hub'dan modeli. Bakteriyel yanıklık, kahverengi çizgi hastalığı, yeşil akarı, mozaik hastalığı, sağlıklı veya bilinmeyen: 6 sınıftan birine manyok yaprakların modeli sınıflandırır görüntüler.
Bu ortak çalışma şunların nasıl yapılacağını gösterir:
- Yük https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2 TensorFlow Hub'dan modeli
- Yük manyok TensorFlow Veri kümeleri gelen veri kümesi (TFDS)
- Manyok yapraklarının görüntülerini 4 farklı manyok hastalığı kategorisine veya sağlıklı veya bilinmeyen olarak sınıflandırın.
- Malı görüntüler dışı uygulandığında modeli ne kadar sağlam olarak sınıflandırıcı ve görünüm doğruluğunu değerlendirin.
İthalat ve kurulum
pip install matplotlib==3.2.2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
Örnekleri görüntülemek için yardımcı fonksiyon
def plot(examples, predictions=None):
# Get the images, labels, and optionally predictions
images = examples['image']
labels = examples['label']
batch_size = len(images)
if predictions is None:
predictions = batch_size * [None]
# Configure the layout of the grid
x = np.ceil(np.sqrt(batch_size))
y = np.ceil(batch_size / x)
fig = plt.figure(figsize=(x * 6, y * 7))
for i, (image, label, prediction) in enumerate(zip(images, labels, predictions)):
# Render the image
ax = fig.add_subplot(x, y, i+1)
ax.imshow(image, aspect='auto')
ax.grid(False)
ax.set_xticks([])
ax.set_yticks([])
# Display the label and optionally prediction
x_label = 'Label: ' + name_map[class_names[label]]
if prediction is not None:
x_label = 'Prediction: ' + name_map[class_names[prediction]] + '\n' + x_label
ax.xaxis.label.set_color('green' if label == prediction else 'red')
ax.set_xlabel(x_label)
plt.show()
veri kümesi
Hadi yük TFDS manyok veri kümesi
dataset, info = tfds.load('cassava', with_info=True)
Bunun hakkında daha fazla bilgi edinmek için, açıklama ve alıntı gibi veri kümesi bilgilerine ve kaç tane örnek bulunduğuna ilişkin bilgilere bir göz atalım.
info
tfds.core.DatasetInfo( name='cassava', full_name='cassava/0.1.0', description=""" Cassava consists of leaf images for the cassava plant depicting healthy and four (4) disease conditions; Cassava Mosaic Disease (CMD), Cassava Bacterial Blight (CBB), Cassava Greem Mite (CGM) and Cassava Brown Streak Disease (CBSD). Dataset consists of a total of 9430 labelled images. The 9430 labelled images are split into a training set (5656), a test set(1885) and a validation set (1889). The number of images per class are unbalanced with the two disease classes CMD and CBSD having 72% of the images. """, homepage='https://www.kaggle.com/c/cassava-disease/overview', data_path='gs://tensorflow-datasets/datasets/cassava/0.1.0', download_size=1.26 GiB, dataset_size=Unknown size, features=FeaturesDict({ 'image': Image(shape=(None, None, 3), dtype=tf.uint8), 'image/filename': Text(shape=(), dtype=tf.string), 'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=5), }), supervised_keys=('image', 'label'), disable_shuffling=False, splits={ 'test': <SplitInfo num_examples=1885, num_shards=4>, 'train': <SplitInfo num_examples=5656, num_shards=8>, 'validation': <SplitInfo num_examples=1889, num_shards=4>, }, citation="""@misc{mwebaze2019icassava, title={iCassava 2019Fine-Grained Visual Categorization Challenge}, author={Ernest Mwebaze and Timnit Gebru and Andrea Frome and Solomon Nsumba and Jeremy Tusubira}, year={2019}, eprint={1908.02900}, archivePrefix={arXiv}, primaryClass={cs.CV} }""", )
Manyok veri kümesi 4 ayrı hastalıklarla manyok yaprakları yanı sıra sağlıklı manyok yaprakların görüntüleri vardır. Model, tahmininden emin olmadığında, tüm bu sınıfların yanı sıra "bilinmeyen" için altıncı sınıfı da tahmin edebilir.
# Extend the cassava dataset classes with 'unknown'
class_names = info.features['label'].names + ['unknown']
# Map the class names to human readable names
name_map = dict(
cmd='Mosaic Disease',
cbb='Bacterial Blight',
cgm='Green Mite',
cbsd='Brown Streak Disease',
healthy='Healthy',
unknown='Unknown')
print(len(class_names), 'classes:')
print(class_names)
print([name_map[name] for name in class_names])
6 classes: ['cbb', 'cbsd', 'cgm', 'cmd', 'healthy', 'unknown'] ['Bacterial Blight', 'Brown Streak Disease', 'Green Mite', 'Mosaic Disease', 'Healthy', 'Unknown']
Verileri modele beslemeden önce biraz ön işleme yapmamız gerekiyor. Model, [0, 1]'de RGB kanal değerlerine sahip 224 x 224 görüntü bekliyor. Görüntüleri normalleştirelim ve yeniden boyutlandıralım.
def preprocess_fn(data):
image = data['image']
# Normalize [0, 255] to [0, 1]
image = tf.cast(image, tf.float32)
image = image / 255.
# Resize the images to 224 x 224
image = tf.image.resize(image, (224, 224))
data['image'] = image
return data
Veri setinden birkaç örneğe bakalım
batch = dataset['validation'].map(preprocess_fn).batch(25).as_numpy_iterator()
examples = next(batch)
plot(examples)
modeli
Sınıflandırıcıyı TF Hub'dan yükleyelim ve bazı tahminler alalım ve modelin tahminlerini birkaç örnek üzerinde görelim.
classifier = hub.KerasLayer('https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2')
probabilities = classifier(examples['image'])
predictions = tf.argmax(probabilities, axis=-1)
plot(examples, predictions)
Değerlendirme ve sağlamlık
En veri kümesinin bir bölünme bizim sınıflandırıcı doğruluğunu ölçmek edelim. Biz de olmayan bir manyok veri kümesi üzerindeki performansını değerlendirerek modelinin sağlamlığı bakabilirsiniz. İNaturalist veya fasulye gibi diğer bitki veri kümelerinin resmi için, model, hemen hemen her zaman bilinmeyen dönmelidir.
parametreler
DATASET = 'cassava'
DATASET_SPLIT = 'test'
BATCH_SIZE = 32
MAX_EXAMPLES = 1000
def label_to_unknown_fn(data):
data['label'] = 5 # Override label to unknown.
return data
# Preprocess the examples and map the image label to unknown for non-cassava datasets.
ds = tfds.load(DATASET, split=DATASET_SPLIT).map(preprocess_fn).take(MAX_EXAMPLES)
dataset_description = DATASET
if DATASET != 'cassava':
ds = ds.map(label_to_unknown_fn)
dataset_description += ' (labels mapped to unknown)'
ds = ds.batch(BATCH_SIZE)
# Calculate the accuracy of the model
metric = tf.keras.metrics.Accuracy()
for examples in ds:
probabilities = classifier(examples['image'])
predictions = tf.math.argmax(probabilities, axis=-1)
labels = examples['label']
metric.update_state(labels, predictions)
print('Accuracy on %s: %.2f' % (dataset_description, metric.result().numpy()))
Accuracy on cassava: 0.88
Daha fazla bilgi edin
- TensorFlow Hub üzerindeki modeli hakkında daha fazla bilgi: https://tfhub.dev/google/cropnet/classifier/cassava_disease_V1/2
- Bir cep telefonu özel bir görüntü sınıflandırıcı çalışan inşa öğrenin ML Kit ile bu modelin TensorFlow Lite sürümü .