TensorFlow.org पर देखें | Google Colab में चलाएं | GitHub पर स्रोत देखें | नोटबुक डाउनलोड करें |
सेट अप
import numpy as np
import tensorflow as tf
from tensorflow import keras
परिचय
स्थानांतरण सीखने एक समस्या पर सीखा सुविधाओं लेने, और एक नया, समान समस्या पर उन्हें लाभ के होते हैं। उदाहरण के लिए, एक मॉडल की विशेषताएं जिसने रैकून की पहचान करना सीख लिया है, एक मॉडल को किक-स्टार्ट करने के लिए उपयोगी हो सकता है, जिसका उद्देश्य तनुकियों की पहचान करना है।
स्थानांतरण सीखना आमतौर पर उन कार्यों के लिए किया जाता है जहां आपके डेटासेट में पूर्ण पैमाने के मॉडल को खरोंच से प्रशिक्षित करने के लिए बहुत कम डेटा होता है।
डीप लर्निंग के संदर्भ में ट्रांसफर लर्निंग का सबसे आम अवतार निम्नलिखित वर्कफ़्लो है:
- पहले से प्रशिक्षित मॉडल से परतें लें।
- उन्हें फ्रीज करें, ताकि भविष्य के प्रशिक्षण दौरों के दौरान उनके पास मौजूद किसी भी जानकारी को नष्ट करने से बचा जा सके।
- जमी हुई परतों के ऊपर कुछ नई, प्रशिक्षित करने योग्य परतें जोड़ें। वे नए डेटासेट पर पुरानी सुविधाओं को भविष्यवाणियों में बदलना सीखेंगे।
- अपने डेटासेट पर नई परतों को प्रशिक्षित करें।
एक अंतिम, वैकल्पिक कदम, ठीक करने, पूरे मॉडल आप ऊपर प्राप्त (या इसे का हिस्सा) unfreezing के होते हैं जो, और एक बहुत कम सीखने की दर के साथ नए डेटा पर यह पुन: प्रशिक्षण है। यह संभावित रूप से नए डेटा के लिए पूर्व-प्रशिक्षित सुविधाओं को बढ़ाकर, सार्थक सुधार प्राप्त कर सकता है।
सबसे पहले, हम Keras की चर्चा करेंगे trainable
जो सबसे हस्तांतरण सीखने और ठीक करने workflows underlies विस्तार से एपीआई,।
फिर, हम इमेजनेट डेटासेट पर पूर्व-प्रशिक्षित मॉडल लेकर और कागल "कैट्स बनाम डॉग्स" वर्गीकरण डेटासेट पर इसे फिर से प्रशिक्षित करके विशिष्ट वर्कफ़्लो प्रदर्शित करेंगे।
इस से अनुकूलित है अजगर के साथ दीप लर्निंग से 2016 के ब्लॉग पोस्ट "बहुत कम डेटा का उपयोग कर शक्तिशाली छवि वर्गीकरण मॉडल के निर्माण" ।
बर्फ़ीली परतों: समझने trainable
विशेषता
परतों और मॉडलों में तीन भार विशेषताएँ होती हैं:
-
weights
परत के सभी वजन चर की सूची है। -
trainable_weights
उन है कि अद्यतन किया जा के लिए होती हैं (ढाल वंश के माध्यम से) प्रशिक्षण के दौरान नुकसान को कम करने की सूची है। -
non_trainable_weights
उन है कि प्रशिक्षित किया जा करने के लिए नहीं कर रहे हैं की सूची है। आमतौर पर वे फॉरवर्ड पास के दौरान मॉडल द्वारा अपडेट किए जाते हैं।
उदाहरण: Dense
परत 2 trainable वजन है (कर्नेल और पूर्वाग्रह)
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2 trainable_weights: 2 non_trainable_weights: 0
सामान्य तौर पर, सभी भार प्रशिक्षण योग्य भार होते हैं। केवल अंतर्निहित परत है कि गैर trainable वजन है BatchNormalization
परत। यह प्रशिक्षण के दौरान अपने इनपुट के माध्य और भिन्नता का ट्रैक रखने के लिए गैर-प्रशिक्षित भार का उपयोग करता है। कैसे, अपने स्वयं के कस्टम परतों में गैर trainable वजन का उपयोग देखने के लिए जानने के लिए खरोंच से नई परतें लिखने के लिए गाइड ।
उदाहरण: BatchNormalization
परत 2 trainable वजन और 2 गैर trainable भार है
layer = keras.layers.BatchNormalization()
layer.build((None, 4)) # Create the weights
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 4 trainable_weights: 2 non_trainable_weights: 2
परतें और मॉडल भी एक बूलियन विशेषता सुविधा trainable
। इसका मूल्य बदला जा सकता है। स्थापना layer.trainable
को False
गैर trainable को trainable से सभी परत के वजन ले जाता है। यह कहा जाता है "ठंड" परत: एक जमे हुए परत के राज्य प्रशिक्षण के दौरान अपडेट नहीं किया जाएगा (या तो जब साथ प्रशिक्षण fit()
या जब कि पर निर्भर करता है किसी भी कस्टम पाश के साथ प्रशिक्षण trainable_weights
ढाल अद्यतन लागू करने)।
उदाहरण: स्थापित करने trainable
को False
layer = keras.layers.Dense(3)
layer.build((None, 4)) # Create the weights
layer.trainable = False # Freeze the layer
print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))
weights: 2 trainable_weights: 0 non_trainable_weights: 2
जब एक प्रशिक्षित वजन गैर-प्रशिक्षित हो जाता है, तो प्रशिक्षण के दौरान इसका मूल्य अपडेट नहीं किया जाता है।
# Make a model with 2 layers
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])
# Freeze the first layer
layer1.trainable = False
# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()
# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))
# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()
np.testing.assert_allclose(
initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
1/1 [==============================] - 1s 640ms/step - loss: 0.0945
भ्रमित न layer.trainable
तर्क के साथ विशेषता training
में layer.__call__()
(जो नियंत्रित होता है कि परत अनुमान मोड या प्रशिक्षण मोड में अपनी फॉरवर्ड पास चलाना चाहिए)। अधिक जानकारी के लिए, Keras पूछे जाने वाले प्रश्न ।
की पुनरावर्ती सेटिंग trainable
विशेषता
आप सेट करते हैं trainable = False
एक मॉडल पर या sublayers है कि किसी भी स्तर पर, सभी बच्चों परतों साथ ही गैर-trainable हो जाते हैं।
उदाहरण:
inner_model = keras.Sequential(
[
keras.Input(shape=(3,)),
keras.layers.Dense(3, activation="relu"),
keras.layers.Dense(3, activation="relu"),
]
)
model = keras.Sequential(
[keras.Input(shape=(3,)), inner_model, keras.layers.Dense(3, activation="sigmoid"),]
)
model.trainable = False # Freeze the outer model
assert inner_model.trainable == False # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False # `trainable` is propagated recursively
विशिष्ट स्थानांतरण-शिक्षण कार्यप्रवाह
यह हमें इस बात की ओर ले जाता है कि केरस में एक विशिष्ट स्थानांतरण सीखने के वर्कफ़्लो को कैसे लागू किया जा सकता है:
- एक बेस मॉडल को इंस्टेंट करें और उसमें पूर्व-प्रशिक्षित वेट लोड करें।
- स्थापना करके बेस मॉडल में सभी परतों फ्रीज
trainable = False
। - बेस मॉडल से एक (या कई) लेयर्स के आउटपुट के ऊपर एक नया मॉडल बनाएं।
- अपने नए मॉडल को अपने नए डेटासेट पर प्रशिक्षित करें।
ध्यान दें कि एक वैकल्पिक, अधिक हल्का वर्कफ़्लो भी हो सकता है:
- एक बेस मॉडल को इंस्टेंट करें और उसमें पूर्व-प्रशिक्षित वेट लोड करें।
- इसके माध्यम से अपना नया डेटासेट चलाएं और बेस मॉडल से एक (या कई) परतों के आउटपुट को रिकॉर्ड करें। यह सुविधा निष्कर्षण कहा जाता है।
- उस आउटपुट का उपयोग एक नए, छोटे मॉडल के लिए इनपुट डेटा के रूप में करें।
उस दूसरे वर्कफ़्लो का एक प्रमुख लाभ यह है कि आप अपने डेटा पर केवल एक बार बेस मॉडल चलाते हैं, न कि प्रशिक्षण के प्रति युग में एक बार। तो यह बहुत तेज और सस्ता है।
हालाँकि, उस दूसरे वर्कफ़्लो के साथ एक समस्या यह है कि यह आपको प्रशिक्षण के दौरान अपने नए मॉडल के इनपुट डेटा को गतिशील रूप से संशोधित करने की अनुमति नहीं देता है, जो उदाहरण के लिए डेटा वृद्धि करते समय आवश्यक है। ट्रांसफर लर्निंग का उपयोग आमतौर पर उन कार्यों के लिए किया जाता है, जब आपके नए डेटासेट में स्क्रैच से पूर्ण-स्केल मॉडल को प्रशिक्षित करने के लिए बहुत कम डेटा होता है, और ऐसे परिदृश्यों में डेटा वृद्धि बहुत महत्वपूर्ण होती है। तो इस प्रकार, हम पहले वर्कफ़्लो पर ध्यान केंद्रित करेंगे।
यहां बताया गया है कि केरस में पहला वर्कफ़्लो कैसा दिखता है:
सबसे पहले, पूर्व-प्रशिक्षित भार के साथ एक बेस मॉडल को इंस्टेंट करें।
base_model = keras.applications.Xception(
weights='imagenet', # Load weights pre-trained on ImageNet.
input_shape=(150, 150, 3),
include_top=False) # Do not include the ImageNet classifier at the top.
फिर, बेस मॉडल को फ्रीज करें।
base_model.trainable = False
शीर्ष पर एक नया मॉडल बनाएं।
inputs = keras.Input(shape=(150, 150, 3))
# We make sure that the base_model is running in inference mode here,
# by passing `training=False`. This is important for fine-tuning, as you will
# learn in a few paragraphs.
x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
# A Dense classifier with a single unit (binary classification)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
नए डेटा पर मॉडल को प्रशिक्षित करें।
model.compile(optimizer=keras.optimizers.Adam(),
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()])
model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)
फ़ाइन ट्यूनिंग
एक बार जब आपका मॉडल नए डेटा में परिवर्तित हो जाता है, तो आप आधार मॉडल के सभी या उसके हिस्से को अनफ़्रीज़ करने का प्रयास कर सकते हैं और बहुत कम सीखने की दर के साथ पूरे मॉडल को एंड-टू-एंड फिर से प्रशिक्षित कर सकते हैं।
यह एक वैकल्पिक अंतिम चरण है जो संभावित रूप से आपको वृद्धिशील सुधार दे सकता है। यह संभावित रूप से त्वरित ओवरफिटिंग का कारण भी बन सकता है - इसे ध्यान में रखें।
यह बाद जमे हुए परतों के साथ मॉडल अभिसरण करने के लिए प्रशिक्षित किया गया है केवल इस कदम करने के लिए महत्वपूर्ण है। यदि आप पूर्व-प्रशिक्षित सुविधाओं को रखने वाली प्रशिक्षित परतों के साथ बेतरतीब ढंग से आरंभिक प्रशिक्षित परतों को मिलाते हैं, तो यादृच्छिक रूप से आरंभ की गई परतें प्रशिक्षण के दौरान बहुत बड़े ग्रेडिएंट अपडेट का कारण बनेंगी, जो आपकी पूर्व-प्रशिक्षित सुविधाओं को नष्ट कर देगी।
इस स्तर पर बहुत कम सीखने की दर का उपयोग करना भी महत्वपूर्ण है, क्योंकि आप प्रशिक्षण के पहले दौर की तुलना में बहुत बड़े मॉडल का प्रशिक्षण दे रहे हैं, एक ऐसे डेटासेट पर जो आमतौर पर बहुत छोटा होता है। नतीजतन, यदि आप बड़े वजन वाले अपडेट लागू करते हैं तो आपको बहुत जल्दी ओवरफिट होने का खतरा होता है। यहां, आप केवल पूर्व-प्रशिक्षित वज़न को वृद्धिशील तरीके से पढ़ना चाहते हैं।
पूरे बेस मॉडल के फाइन-ट्यूनिंग को लागू करने का तरीका इस प्रकार है:
# Unfreeze the base model
base_model.trainable = True
# It's important to recompile your model after you make any changes
# to the `trainable` attribute of any inner layer, so that your changes
# are take into account
model.compile(optimizer=keras.optimizers.Adam(1e-5), # Very low learning rate
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()])
# Train end-to-end. Be careful to stop before you overfit!
model.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)
बारे में महत्वपूर्ण सूचना compile()
और trainable
कॉलिंग compile()
एक मॉडल का मतलब है पर इस मॉडल का व्यवहार "फ्रीज"। इसका मतलब है कि trainable
समय मॉडल संकलित किया गया है पर विशेषता मान, इस मॉडल का जीवन भर संरक्षित किया जाना चाहिए जब तक compile
फिर कहा जाता है। इसलिए, अगर आप किसी भी बदल trainable
मूल्य, यकीन है कि कॉल करने के लिए बनाने के compile()
फिर से अपने मॉडल पर के लिए अपने परिवर्तनों को ध्यान में रखा जाना करने के लिए।
बारे में महत्वपूर्ण सूचनाएं BatchNormalization
परत
कई छवि मॉडल शामिल BatchNormalization
परतों। वह परत हर कल्पनीय गिनती पर एक विशेष मामला है। यहाँ कुछ बातों का ध्यान रखना है।
-
BatchNormalization
2 गैर trainable वजन कि प्रशिक्षण के दौरान अद्यतन में शामिल है। ये वेरिएबल हैं जो इनपुट के माध्य और भिन्नता को ट्रैक करते हैं। - जब आप सेट
bn_layer.trainable = False
,BatchNormalization
परत अनुमान मोड में चलेगा, और उसके मतलब और विचरण आंकड़ों को अपडेट नहीं होंगे। यह सामान्य रूप से अन्य परतों, के रूप में के लिए मामला नहीं है वजन trainability और निष्कर्ष / प्रशिक्षण मोड दो ओर्थोगोनल अवधारणाओं हैं । लेकिन दो के मामले में बंधे होते हैंBatchNormalization
परत। - जब आप एक मॉडल है कि शामिल नरम कर देना
BatchNormalization
आदेश ठीक करने करने के लिए परतों, आप रखना चाहिएBatchNormalization
पास करके निष्कर्ष मोड में परतोंtraining=False
जब बेस मॉडल करार दिया। अन्यथा गैर-प्रशिक्षित भारों पर लागू किए गए अपडेट मॉडल द्वारा सीखी गई बातों को अचानक नष्ट कर देंगे।
आप इस गाइड के अंत में एंड-टू-एंड उदाहरण में इस पैटर्न को क्रिया में देखेंगे।
एक कस्टम प्रशिक्षण लूप के साथ सीखने और फ़ाइन-ट्यूनिंग को स्थानांतरित करें
के बजाय अगर fit()
, आप अपने खुद के निम्न स्तर के प्रशिक्षण पाश का उपयोग कर रहे हैं, कार्यप्रवाह रहता है मूलतः एक ही। आपको सावधान रहना चाहिए के लिए एक ही खाते में सूची ले model.trainable_weights
जब ढाल अद्यतन को लागू करने:
# Create base model
base_model = keras.applications.Xception(
weights='imagenet',
input_shape=(150, 150, 3),
include_top=False)
# Freeze base model
base_model.trainable = False
# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()
# Iterate over the batches of a dataset.
for inputs, targets in new_dataset:
# Open a GradientTape.
with tf.GradientTape() as tape:
# Forward pass.
predictions = model(inputs)
# Compute the loss value for this batch.
loss_value = loss_fn(targets, predictions)
# Get gradients of loss wrt the *trainable* weights.
gradients = tape.gradient(loss_value, model.trainable_weights)
# Update the weights of the model.
optimizer.apply_gradients(zip(gradients, model.trainable_weights))
इसी तरह फाइन-ट्यूनिंग के लिए।
एक एंड-टू-एंड उदाहरण: एक बिल्ली बनाम कुत्तों के डेटासेट पर एक छवि वर्गीकरण मॉडल को फाइन-ट्यूनिंग करना
इन अवधारणाओं को मजबूत करने के लिए, आइए आपको एक ठोस एंड-टू-एंड ट्रांसफर लर्निंग और फाइन-ट्यूनिंग उदाहरण के माध्यम से चलते हैं। हम इमेजनेट पर पूर्व-प्रशिक्षित एक्ससेप्शन मॉडल को लोड करेंगे, और कागल "बिल्लियों बनाम कुत्तों" वर्गीकरण डेटासेट पर इसका उपयोग करेंगे।
डेटा प्राप्त करना
सबसे पहले, आइए TFDS का उपयोग करके कैट्स बनाम डॉग्स डेटासेट प्राप्त करें। आप अपने खुद के डाटासेट है, तो आप शायद उपयोगिता का उपयोग करना चाहेंगे tf.keras.preprocessing.image_dataset_from_directory
समान लेबल डाटासेट वर्ग विशेष फ़ोल्डरों में दायर डिस्क पर छवियों का एक सेट से वस्तुओं उत्पन्न करने के लिए।
बहुत छोटे डेटासेट के साथ काम करते समय स्थानांतरण सीखना सबसे उपयोगी होता है। अपने डेटासेट को छोटा रखने के लिए, हम प्रशिक्षण के लिए मूल प्रशिक्षण डेटा का 40% (25,000 चित्र), सत्यापन के लिए 10% और परीक्षण के लिए 10% का उपयोग करेंगे।
import tensorflow_datasets as tfds
tfds.disable_progress_bar()
train_ds, validation_ds, test_ds = tfds.load(
"cats_vs_dogs",
# Reserve 10% for validation and 10% for test
split=["train[:40%]", "train[40%:50%]", "train[50%:60%]"],
as_supervised=True, # Include labels
)
print("Number of training samples: %d" % tf.data.experimental.cardinality(train_ds))
print(
"Number of validation samples: %d" % tf.data.experimental.cardinality(validation_ds)
)
print("Number of test samples: %d" % tf.data.experimental.cardinality(test_ds))
Number of training samples: 9305 Number of validation samples: 2326 Number of test samples: 2326
प्रशिक्षण डेटासेट में ये पहली 9 छवियां हैं - जैसा कि आप देख सकते हैं, ये सभी अलग-अलग आकार के हैं।
import matplotlib.pyplot as plt
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(train_ds.take(9)):
ax = plt.subplot(3, 3, i + 1)
plt.imshow(image)
plt.title(int(label))
plt.axis("off")
हम यह भी देख सकते हैं कि लेबल 1 "कुत्ता" है और लेबल 0 "बिल्ली" है।
डेटा का मानकीकरण
हमारी कच्ची छवियों में कई प्रकार के आकार होते हैं। इसके अलावा, प्रत्येक पिक्सेल में 0 और 255 (RGB स्तर मान) के बीच 3 पूर्णांक मान होते हैं। यह तंत्रिका नेटवर्क को खिलाने के लिए बहुत उपयुक्त नहीं है। हमें 2 चीजें करने की जरूरत है:
- एक निश्चित छवि आकार के लिए मानकीकृत करें। हम 150x150 चुनते हैं।
- सामान्य पिक्सेल मूल्यों के बीच -1 और 1. हम एक का उपयोग कर यह करूँगा
Normalization
मॉडल खुद के हिस्से के रूप परत।
सामान्य तौर पर, कच्चे डेटा को इनपुट के रूप में लेने वाले मॉडल विकसित करना एक अच्छा अभ्यास है, जो पहले से संसाधित डेटा लेने वाले मॉडल के विपरीत है। इसका कारण यह है कि, यदि आपका मॉडल प्रीप्रोसेस्ड डेटा की अपेक्षा करता है, तो जब भी आप अपने मॉडल को कहीं और उपयोग करने के लिए निर्यात करते हैं (वेब ब्राउज़र में, मोबाइल ऐप में), तो आपको ठीक उसी प्रीप्रोसेसिंग पाइपलाइन को फिर से लागू करने की आवश्यकता होगी। यह बहुत जल्दी बहुत मुश्किल हो जाता है। इसलिए हमें मॉडल को हिट करने से पहले कम से कम संभव मात्रा में प्रीप्रोसेसिंग करनी चाहिए।
यहां, हम डेटा पाइपलाइन में इमेज रीसाइज़िंग करेंगे (क्योंकि एक डीप न्यूरल नेटवर्क केवल डेटा के सन्निहित बैचों को प्रोसेस कर सकता है), और जब हम इसे बनाते हैं, तो हम मॉडल के हिस्से के रूप में इनपुट वैल्यू स्केलिंग करेंगे।
आइए छवियों का आकार बदलकर 150x150 करें:
size = (150, 150)
train_ds = train_ds.map(lambda x, y: (tf.image.resize(x, size), y))
validation_ds = validation_ds.map(lambda x, y: (tf.image.resize(x, size), y))
test_ds = test_ds.map(lambda x, y: (tf.image.resize(x, size), y))
इसके अलावा, आइए डेटा को बैच करें और लोडिंग गति को अनुकूलित करने के लिए कैशिंग और प्रीफ़ेचिंग का उपयोग करें।
batch_size = 32
train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)
यादृच्छिक डेटा वृद्धि का उपयोग करना
जब आपके पास एक बड़ा छवि डेटासेट नहीं होता है, तो प्रशिक्षण छवियों में यादृच्छिक लेकिन यथार्थवादी परिवर्तनों को लागू करके नमूना विविधता को कृत्रिम रूप से पेश करना एक अच्छा अभ्यास है, जैसे कि यादृच्छिक क्षैतिज फ़्लिपिंग या छोटे यादृच्छिक घुमाव। यह ओवरफिटिंग को धीमा करते हुए मॉडल को प्रशिक्षण डेटा के विभिन्न पहलुओं को उजागर करने में मदद करता है।
from tensorflow import keras
from tensorflow.keras import layers
data_augmentation = keras.Sequential(
[layers.RandomFlip("horizontal"), layers.RandomRotation(0.1),]
)
आइए कल्पना करें कि विभिन्न यादृच्छिक परिवर्तनों के बाद पहले बैच की पहली छवि कैसी दिखती है:
import numpy as np
for images, labels in train_ds.take(1):
plt.figure(figsize=(10, 10))
first_image = images[0]
for i in range(9):
ax = plt.subplot(3, 3, i + 1)
augmented_image = data_augmentation(
tf.expand_dims(first_image, 0), training=True
)
plt.imshow(augmented_image[0].numpy().astype("int32"))
plt.title(int(labels[0]))
plt.axis("off")
2021-09-01 18:45:34.772284: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
मॉडल बनाना
आइए अब एक मॉडल बनाते हैं जो उस ब्लूप्रिंट का अनुसरण करता है जिसे हमने पहले समझाया था।
ध्यान दें कि:
- हम एक जोड़ने
Rescaling
पैमाने इनपुट मानों को परत (शुरू में[0, 255]
के लिए रेंज)[-1, 1]
रेंज। - हम एक जोड़ने
Dropout
नियमितीकरण के लिए वर्गीकरण परत से पहले परत,। - हम पारित करने के लिए सुनिश्चित करें कि
training=False
इतना है कि यह, अनुमान मोड में चलाता है ताकि batchnorm आंकड़ों को अपडेट नहीं मिलता के बाद भी हम ठीक करने के लिए बेस मॉडल नरम कर देना जब बेस मॉडल बुला,।
base_model = keras.applications.Xception(
weights="imagenet", # Load weights pre-trained on ImageNet.
input_shape=(150, 150, 3),
include_top=False,
) # Do not include the ImageNet classifier at the top.
# Freeze the base_model
base_model.trainable = False
# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs) # Apply random data augmentation
# Pre-trained Xception weights requires that input be scaled
# from (0, 255) to a range of (-1., +1.), the rescaling layer
# outputs: `(inputs * scale) + offset`
scale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)
x = scale_layer(x)
# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x) # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.summary()
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5 83689472/83683744 [==============================] - 2s 0us/step 83697664/83683744 [==============================] - 2s 0us/step Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_5 (InputLayer) [(None, 150, 150, 3)] 0 _________________________________________________________________ sequential_3 (Sequential) (None, 150, 150, 3) 0 _________________________________________________________________ rescaling (Rescaling) (None, 150, 150, 3) 0 _________________________________________________________________ xception (Functional) (None, 5, 5, 2048) 20861480 _________________________________________________________________ global_average_pooling2d (Gl (None, 2048) 0 _________________________________________________________________ dropout (Dropout) (None, 2048) 0 _________________________________________________________________ dense_7 (Dense) (None, 1) 2049 ================================================================= Total params: 20,863,529 Trainable params: 2,049 Non-trainable params: 20,861,480 _________________________________________________________________
शीर्ष परत को प्रशिक्षित करें
model.compile(
optimizer=keras.optimizers.Adam(),
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 20
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Epoch 1/20 151/291 [==============>...............] - ETA: 3s - loss: 0.1979 - binary_accuracy: 0.9096 Corrupt JPEG data: 65 extraneous bytes before marker 0xd9 268/291 [==========================>...] - ETA: 1s - loss: 0.1663 - binary_accuracy: 0.9269 Corrupt JPEG data: 239 extraneous bytes before marker 0xd9 282/291 [============================>.] - ETA: 0s - loss: 0.1628 - binary_accuracy: 0.9284 Corrupt JPEG data: 1153 extraneous bytes before marker 0xd9 Corrupt JPEG data: 228 extraneous bytes before marker 0xd9 291/291 [==============================] - ETA: 0s - loss: 0.1620 - binary_accuracy: 0.9286 Corrupt JPEG data: 2226 extraneous bytes before marker 0xd9 291/291 [==============================] - 29s 63ms/step - loss: 0.1620 - binary_accuracy: 0.9286 - val_loss: 0.0814 - val_binary_accuracy: 0.9686 Epoch 2/20 291/291 [==============================] - 8s 29ms/step - loss: 0.1178 - binary_accuracy: 0.9511 - val_loss: 0.0785 - val_binary_accuracy: 0.9695 Epoch 3/20 291/291 [==============================] - 9s 30ms/step - loss: 0.1121 - binary_accuracy: 0.9536 - val_loss: 0.0748 - val_binary_accuracy: 0.9712 Epoch 4/20 291/291 [==============================] - 9s 29ms/step - loss: 0.1082 - binary_accuracy: 0.9554 - val_loss: 0.0754 - val_binary_accuracy: 0.9703 Epoch 5/20 291/291 [==============================] - 8s 29ms/step - loss: 0.1034 - binary_accuracy: 0.9570 - val_loss: 0.0721 - val_binary_accuracy: 0.9725 Epoch 6/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0975 - binary_accuracy: 0.9602 - val_loss: 0.0748 - val_binary_accuracy: 0.9699 Epoch 7/20 291/291 [==============================] - 9s 29ms/step - loss: 0.0989 - binary_accuracy: 0.9595 - val_loss: 0.0732 - val_binary_accuracy: 0.9716 Epoch 8/20 291/291 [==============================] - 8s 29ms/step - loss: 0.1027 - binary_accuracy: 0.9566 - val_loss: 0.0787 - val_binary_accuracy: 0.9678 Epoch 9/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0959 - binary_accuracy: 0.9614 - val_loss: 0.0734 - val_binary_accuracy: 0.9729 Epoch 10/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0995 - binary_accuracy: 0.9588 - val_loss: 0.0717 - val_binary_accuracy: 0.9721 Epoch 11/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0957 - binary_accuracy: 0.9612 - val_loss: 0.0731 - val_binary_accuracy: 0.9725 Epoch 12/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0936 - binary_accuracy: 0.9622 - val_loss: 0.0751 - val_binary_accuracy: 0.9716 Epoch 13/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0965 - binary_accuracy: 0.9610 - val_loss: 0.0821 - val_binary_accuracy: 0.9695 Epoch 14/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0939 - binary_accuracy: 0.9618 - val_loss: 0.0742 - val_binary_accuracy: 0.9712 Epoch 15/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0974 - binary_accuracy: 0.9585 - val_loss: 0.0771 - val_binary_accuracy: 0.9712 Epoch 16/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9621 - val_loss: 0.0823 - val_binary_accuracy: 0.9699 Epoch 17/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0947 - binary_accuracy: 0.9625 - val_loss: 0.0718 - val_binary_accuracy: 0.9708 Epoch 18/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0928 - binary_accuracy: 0.9616 - val_loss: 0.0738 - val_binary_accuracy: 0.9716 Epoch 19/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0922 - binary_accuracy: 0.9644 - val_loss: 0.0743 - val_binary_accuracy: 0.9716 Epoch 20/20 291/291 [==============================] - 8s 29ms/step - loss: 0.0885 - binary_accuracy: 0.9635 - val_loss: 0.0745 - val_binary_accuracy: 0.9695 <keras.callbacks.History at 0x7f849a3b2950>
संपूर्ण मॉडल की फ़ाइन-ट्यूनिंग का एक दौर करें
अंत में, आइए बेस मॉडल को अनफ्रीज करें और कम सीखने की दर के साथ पूरे मॉडल को एंड-टू-एंड प्रशिक्षित करें।
महत्वपूर्ण बात है, हालांकि बेस मॉडल trainable हो जाता है, यह अभी भी अनुमान मोड में चल रहा क्योंकि हम पारित कर दिया training=False
जब यह बुला जब हम मॉडल बनाया। इसका मतलब है कि अंदर बैच सामान्यीकरण परतें अपने बैच के आंकड़ों को अपडेट नहीं करेंगी। यदि उन्होंने ऐसा किया, तो वे मॉडल द्वारा अब तक सीखे गए अभ्यावेदन पर कहर बरपाएंगे।
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()
model.compile(
optimizer=keras.optimizers.Adam(1e-5), # Low learning rate
loss=keras.losses.BinaryCrossentropy(from_logits=True),
metrics=[keras.metrics.BinaryAccuracy()],
)
epochs = 10
model.fit(train_ds, epochs=epochs, validation_data=validation_ds)
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_5 (InputLayer) [(None, 150, 150, 3)] 0 _________________________________________________________________ sequential_3 (Sequential) (None, 150, 150, 3) 0 _________________________________________________________________ rescaling (Rescaling) (None, 150, 150, 3) 0 _________________________________________________________________ xception (Functional) (None, 5, 5, 2048) 20861480 _________________________________________________________________ global_average_pooling2d (Gl (None, 2048) 0 _________________________________________________________________ dropout (Dropout) (None, 2048) 0 _________________________________________________________________ dense_7 (Dense) (None, 1) 2049 ================================================================= Total params: 20,863,529 Trainable params: 20,809,001 Non-trainable params: 54,528 _________________________________________________________________ Epoch 1/10 291/291 [==============================] - 43s 131ms/step - loss: 0.0802 - binary_accuracy: 0.9692 - val_loss: 0.0580 - val_binary_accuracy: 0.9764 Epoch 2/10 291/291 [==============================] - 37s 128ms/step - loss: 0.0542 - binary_accuracy: 0.9792 - val_loss: 0.0529 - val_binary_accuracy: 0.9764 Epoch 3/10 291/291 [==============================] - 37s 128ms/step - loss: 0.0400 - binary_accuracy: 0.9832 - val_loss: 0.0510 - val_binary_accuracy: 0.9798 Epoch 4/10 291/291 [==============================] - 37s 128ms/step - loss: 0.0313 - binary_accuracy: 0.9879 - val_loss: 0.0505 - val_binary_accuracy: 0.9819 Epoch 5/10 291/291 [==============================] - 37s 128ms/step - loss: 0.0272 - binary_accuracy: 0.9904 - val_loss: 0.0485 - val_binary_accuracy: 0.9807 Epoch 6/10 291/291 [==============================] - 37s 128ms/step - loss: 0.0284 - binary_accuracy: 0.9901 - val_loss: 0.0497 - val_binary_accuracy: 0.9824 Epoch 7/10 291/291 [==============================] - 37s 127ms/step - loss: 0.0198 - binary_accuracy: 0.9937 - val_loss: 0.0530 - val_binary_accuracy: 0.9802 Epoch 8/10 291/291 [==============================] - 37s 127ms/step - loss: 0.0173 - binary_accuracy: 0.9930 - val_loss: 0.0572 - val_binary_accuracy: 0.9819 Epoch 9/10 291/291 [==============================] - 37s 127ms/step - loss: 0.0113 - binary_accuracy: 0.9958 - val_loss: 0.0555 - val_binary_accuracy: 0.9837 Epoch 10/10 291/291 [==============================] - 37s 127ms/step - loss: 0.0091 - binary_accuracy: 0.9966 - val_loss: 0.0596 - val_binary_accuracy: 0.9832 <keras.callbacks.History at 0x7f83982d4cd0>
10 युगों के बाद, फाइन-ट्यूनिंग से हमें यहाँ एक अच्छा सुधार मिला है।