ในบทช่วยสอนนี้ คุณจะได้สำรวจตัวอย่างเว็บแอปพลิเคชันที่สาธิตการถ่ายโอนการเรียนรู้โดยใช้ TensorFlow.js Layers API ตัวอย่างจะโหลดโมเดลที่ได้รับการฝึกล่วงหน้า จากนั้นฝึกโมเดลใหม่ในเบราว์เซอร์
โมเดลนี้ได้รับการฝึกอบรมล่วงหน้าใน Python ด้วยตัวเลข 0-4 ของ ชุดข้อมูลการจำแนกประเภทหลัก MNIST การอบรมขึ้นใหม่ (หรือถ่ายโอนการเรียนรู้) ในเบราว์เซอร์ใช้ตัวเลข 5-9 ตัวอย่างแสดงให้เห็นว่าโมเดลที่ได้รับการฝึกอบรมล่วงหน้าหลายชั้นแรกสามารถใช้เพื่อแยกคุณสมบัติจากข้อมูลใหม่ในระหว่างการเรียนรู้การถ่ายโอน ซึ่งช่วยให้สามารถฝึกอบรมข้อมูลใหม่ได้เร็วขึ้น
แอปพลิเคชันตัวอย่างสำหรับบทช่วยสอนนี้มี ให้ใช้งานออนไลน์ ดังนั้นคุณไม่จำเป็นต้องดาวน์โหลดโค้ดใดๆ หรือตั้งค่าสภาพแวดล้อมการพัฒนา หากคุณต้องการเรียกใช้โค้ดในเครื่อง ให้ทำตามขั้นตอนเสริมใน เรียกใช้ตัวอย่างในเครื่อง หากคุณไม่ต้องการตั้งค่าสภาพแวดล้อมการพัฒนา คุณสามารถข้ามไปที่ สำรวจตัวอย่าง ได้
โค้ดตัวอย่างมีอยู่ใน GitHub
(ไม่บังคับ) เรียกใช้ตัวอย่างในเครื่อง
ข้อกำหนดเบื้องต้น
หากต้องการเรียกใช้แอปตัวอย่างในเครื่อง คุณต้องติดตั้งสิ่งต่อไปนี้ในสภาพแวดล้อมการพัฒนาของคุณ:
ติดตั้งและเรียกใช้แอปตัวอย่าง
- โคลนหรือดาวน์โหลดพื้นที่เก็บ
tfjs-examples
เปลี่ยนเป็นไดเร็กทอรี
mnist-transfer-cnn
:cd tfjs-examples/mnist-transfer-cnn
ติดตั้งการพึ่งพา:
yarn
เริ่มเซิร์ฟเวอร์การพัฒนา:
yarn run watch
สำรวจตัวอย่าง
เปิดแอปตัวอย่าง (หรือหากคุณใช้งานตัวอย่างในเครื่อง ให้ไปที่ http://localhost:1234
ในเบราว์เซอร์ของคุณ)
คุณควรเห็นหน้าที่ชื่อ MNIST CNN Transfer Learning ทำตามคำแนะนำเพื่อลองใช้แอป
ต่อไปนี้เป็นบางสิ่งที่ควรลอง:
- ทดลองใช้โหมดการฝึกต่างๆ และเปรียบเทียบการสูญเสียและความแม่นยำ
- เลือกตัวอย่างบิตแมปที่แตกต่างกัน และตรวจสอบความน่าจะเป็นในการจัดหมวดหมู่ โปรดทราบว่าตัวเลขในแต่ละตัวอย่างบิตแมปเป็นค่าจำนวนเต็มระดับสีเทาที่แสดงถึงพิกเซลจากรูปภาพ
- แก้ไขค่าจำนวนเต็มบิตแมปและดูว่าการเปลี่ยนแปลงส่งผลต่อความน่าจะเป็นในการจำแนกประเภทอย่างไร
สำรวจรหัส
เว็บแอปตัวอย่างจะโหลดโมเดลที่ได้รับการฝึกอบรมล่วงหน้าในชุดย่อยของชุดข้อมูล MNIST การฝึกอบรมล่วงหน้าถูกกำหนดไว้ในโปรแกรม Python: mnist_transfer_cnn.py
โปรแกรม Python อยู่นอกขอบเขตสำหรับบทช่วยสอนนี้ แต่ก็คุ้มค่าที่จะดูว่าคุณต้องการดูตัวอย่าง การแปลงโมเดล หรือไม่
ไฟล์ index.js
มีโค้ดการฝึกอบรมส่วนใหญ่สำหรับการสาธิต เมื่อ index.js
ทำงานในเบราว์เซอร์ ฟังก์ชันการตั้งค่า setupMnistTransferCNN
จะสร้างอินสแตนซ์และเริ่มต้น MnistTransferCNNPredictor
ซึ่งจะสรุปกิจวัตรการฝึกอบรมใหม่และการทำนาย
วิธีการเริ่มต้น MnistTransferCNNPredictor.init
จะโหลดโมเดล โหลดข้อมูลการฝึกอบรมใหม่ และสร้างข้อมูลการทดสอบ นี่คือ บรรทัด ที่โหลดโมเดล:
this.model = await loader.loadHostedPretrainedModel(urls.model);
หากคุณดูคำจำกัดความของ loader.loadHostedPretrainedModel
คุณจะเห็นว่ามันส่งคืนผลลัพธ์ของการเรียกไปที่ tf.loadLayersModel
นี่คือ TensorFlow.js API สำหรับการโหลดโมเดลที่ประกอบด้วยออบเจ็กต์เลเยอร์
ตรรกะการฝึกอบรมขึ้นใหม่ถูกกำหนดไว้ใน MnistTransferCNNPredictor.retrainModel
หากผู้ใช้เลือก ตรึงเลเยอร์คุณลักษณะ เป็นโหมดการฝึก 7 เลเยอร์แรกของโมเดลพื้นฐานจะถูกตรึง และเฉพาะ 5 เลเยอร์สุดท้ายเท่านั้นที่จะได้รับการฝึกกับข้อมูลใหม่ หากผู้ใช้เลือก กำหนด ค่าเริ่มต้นใหม่ น้ำหนัก น้ำหนักทั้งหมดจะถูกรีเซ็ต และแอปจะฝึกโมเดลตั้งแต่เริ่มต้นอย่างมีประสิทธิภาพ
if (trainingMode === 'freeze-feature-layers') {
console.log('Freezing feature layers of the model.');
for (let i = 0; i < 7; ++i) {
this.model.layers[i].trainable = false;
}
} else if (trainingMode === 'reinitialize-weights') {
// Make a model with the same topology as before, but with re-initialized
// weight values.
const returnString = false;
this.model = await tf.models.modelFromJSON({
modelTopology: this.model.toJSON(null, returnString)
});
}
จากนั้นโมเดลจะถูก คอมไพล์ จากนั้นจะ ถูกฝึก เกี่ยวกับข้อมูลทดสอบโดยใช้ model.fit()
:
await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
batchSize: batchSize,
epochs: epochs,
validationData: [this.gte5TestData.x, this.gte5TestData.y],
callbacks: [
ui.getProgressBarCallbackConfig(epochs),
tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
zoomToFit: true,
zoomToFitAccuracy: true,
height: 200,
callbacks: ['onEpochEnd'],
}),
]
});
หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับพารามิเตอร์ model.fit()
โปรดดู เอกสารประกอบ API
หลังจากได้รับการฝึกอบรมเกี่ยวกับชุดข้อมูลใหม่ (ตัวเลข 5-9) แล้ว แบบจำลองจะสามารถนำมาใช้ในการคาดการณ์ได้ เมธอด MnistTransferCNNPredictor.predict
ทำสิ่งนี้โดยใช้ model.predict()
:
// Perform prediction on the input image using the loaded model.
predict(imageText) {
tf.tidy(() => {
try {
const image = util.textToImageArray(imageText, this.imageSize);
const predictOut = this.model.predict(image);
const winner = predictOut.argMax(1);
ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
} catch (e) {
ui.setPredictError(e.message);
}
});
}
สังเกตการใช้ tf.tidy
ซึ่งช่วยป้องกันหน่วยความจำรั่ว
เรียนรู้เพิ่มเติม
บทช่วยสอนนี้ได้สำรวจแอปตัวอย่างที่ทำการถ่ายโอนการเรียนรู้ในเบราว์เซอร์โดยใช้ TensorFlow.js ดูแหล่งข้อมูลด้านล่างเพื่อเรียนรู้เพิ่มเติมเกี่ยวกับโมเดลที่ได้รับการฝึกอบรมล่วงหน้าและถ่ายทอดการเรียนรู้
TensorFlow.js
- การนำเข้าโมเดล Keras ไปยัง TensorFlow.js
- นำเข้าโมเดล TensorFlow ไปยัง TensorFlow.js
- โมเดลที่สร้างไว้ล่วงหน้าสำหรับ TensorFlow.js
แกนเทนเซอร์โฟลว์