در این آموزش شما یک نمونه برنامه وب را بررسی خواهید کرد که یادگیری انتقال را با استفاده از TensorFlow.js Layers API نشان می دهد. مثال یک مدل از پیش آموزش دیده را بارگیری می کند و سپس مدل را در مرورگر دوباره آموزش می دهد.
این مدل در پایتون روی ارقام 0-4 مجموعه داده طبقه بندی ارقام MNIST از قبل آموزش داده شده است. آموزش مجدد (یا انتقال یادگیری) در مرورگر از ارقام 5-9 استفاده می کند. این مثال نشان می دهد که چندین لایه اول یک مدل از پیش آموزش دیده می تواند برای استخراج ویژگی ها از داده های جدید در طول یادگیری انتقال استفاده شود، بنابراین امکان آموزش سریعتر روی داده های جدید فراهم می شود.
برنامه نمونه برای این آموزش به صورت آنلاین در دسترس است، بنابراین نیازی به دانلود کد یا راه اندازی یک محیط توسعه ندارید. اگر می خواهید کد را به صورت محلی اجرا کنید، مراحل اختیاری را در اجرای مثال به صورت محلی کامل کنید. اگر نمیخواهید یک محیط توسعه راهاندازی کنید، میتوانید به کاوش مثال بروید.
کد نمونه در GitHub موجود است.
(اختیاری) مثال را به صورت محلی اجرا کنید
پیش نیازها
برای اجرای برنامه نمونه به صورت محلی، باید موارد زیر را در محیط توسعه خود نصب کنید:
برنامه نمونه را نصب و اجرا کنید
- مخزن
tfjs-examples
را کلون یا دانلود کنید. به دایرکتوری
mnist-transfer-cnn
تغییر دهید:cd tfjs-examples/mnist-transfer-cnn
نصب وابستگی ها:
yarn
سرور توسعه را راه اندازی کنید:
yarn run watch
مثال را بررسی کنید
برنامه نمونه را باز کنید . (یا اگر مثال را به صورت محلی اجرا می کنید، به http://localhost:1234
در مرورگر خود بروید.)
شما باید صفحه ای با عنوان MNIST CNN Transfer Learning را ببینید. دستورالعمل ها را دنبال کنید تا برنامه را امتحان کنید.
در اینجا چند چیز برای امتحان وجود دارد:
- با حالتهای مختلف تمرین آزمایش کنید و ضرر و دقت را مقایسه کنید.
- نمونه های مختلف بیت مپ را انتخاب کنید و احتمالات طبقه بندی را بررسی کنید. توجه داشته باشید که اعداد در هر نمونه بیت مپ مقادیر صحیح مقیاس خاکستری هستند که پیکسل های یک تصویر را نشان می دهند.
- مقادیر عدد صحیح بیت مپ را ویرایش کنید و ببینید که چگونه تغییرات بر احتمالات طبقه بندی تاثیر می گذارد.
کد را کاوش کنید
برنامه وب مثال مدلی را بارگیری می کند که از قبل روی زیرمجموعه ای از مجموعه داده MNIST آموزش داده شده است. پیشآموزش در یک برنامه پایتون تعریف شده است: mnist_transfer_cnn.py
. برنامه پایتون برای این آموزش خارج از محدوده است، اما اگر میخواهید نمونهای از تبدیل مدل را ببینید، ارزش دیدن آن را دارد.
فایل index.js
شامل اکثر کدهای آموزشی نسخه ی نمایشی است. هنگامی که index.js
در مرورگر اجرا میشود، یک تابع راهاندازی، setupMnistTransferCNN
، MnistTransferCNNPredictor
را نمونهسازی و مقداردهی اولیه میکند، که روالهای بازآموزی و پیشبینی را محصور میکند.
روش اولیه، MnistTransferCNNPredictor.init
، یک مدل را بارگیری می کند، داده های بازآموزی را بارگیری می کند و داده های آزمایشی را ایجاد می کند. این خطی است که مدل را بارگذاری می کند:
this.model = await loader.loadHostedPretrainedModel(urls.model);
اگر به تعریف loader.loadHostedPretrainedModel
نگاه کنید، خواهید دید که نتیجه تماس به tf.loadLayersModel
را برمی گرداند. این API TensorFlow.js برای بارگیری مدلی متشکل از اشیاء لایه است.
منطق بازآموزی در MnistTransferCNNPredictor.retrainModel
تعریف شده است. اگر کاربر لایههای ویژگی Freeze را به عنوان حالت آموزشی انتخاب کرده باشد، 7 لایه اول مدل پایه ثابت میشوند و تنها 5 لایه نهایی بر روی دادههای جدید آموزش داده میشوند. اگر کاربر وزنهای اولیه را انتخاب کرده باشد، همه وزنها بازنشانی میشوند و برنامه به طور موثر مدل را از ابتدا آموزش میدهد.
if (trainingMode === 'freeze-feature-layers') {
console.log('Freezing feature layers of the model.');
for (let i = 0; i < 7; ++i) {
this.model.layers[i].trainable = false;
}
} else if (trainingMode === 'reinitialize-weights') {
// Make a model with the same topology as before, but with re-initialized
// weight values.
const returnString = false;
this.model = await tf.models.modelFromJSON({
modelTopology: this.model.toJSON(null, returnString)
});
}
سپس مدل کامپایل می شود و سپس با استفاده از model.fit()
بر روی داده های تست آموزش داده می شود:
await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
batchSize: batchSize,
epochs: epochs,
validationData: [this.gte5TestData.x, this.gte5TestData.y],
callbacks: [
ui.getProgressBarCallbackConfig(epochs),
tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
zoomToFit: true,
zoomToFitAccuracy: true,
height: 200,
callbacks: ['onEpochEnd'],
}),
]
});
برای کسب اطلاعات بیشتر در مورد پارامترهای model.fit()
به مستندات API مراجعه کنید.
پس از آموزش بر روی مجموعه داده جدید (رقم های 5-9)، این مدل می تواند برای پیش بینی استفاده شود. متد MnistTransferCNNPredictor.predict
این کار را با استفاده از model.predict()
انجام می دهد:
// Perform prediction on the input image using the loaded model.
predict(imageText) {
tf.tidy(() => {
try {
const image = util.textToImageArray(imageText, this.imageSize);
const predictOut = this.model.predict(image);
const winner = predictOut.argMax(1);
ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
} catch (e) {
ui.setPredictError(e.message);
}
});
}
به استفاده از tf.tidy
توجه کنید که به جلوگیری از نشت حافظه کمک می کند.
بیشتر بدانید
این آموزش یک برنامه نمونه را بررسی کرده است که یادگیری انتقال را در مرورگر با استفاده از TensorFlow.js انجام می دهد. برای کسب اطلاعات بیشتر در مورد مدل های از پیش آموزش دیده و انتقال یادگیری، منابع زیر را بررسی کنید.
TensorFlow.js
- وارد کردن یک مدل Keras به TensorFlow.js
- یک مدل TensorFlow را به TensorFlow.js وارد کنید
- مدل های از پیش ساخته شده برای TensorFlow.js
هسته TensorFlow