ใช้แบบจำลองที่ผ่านการฝึกอบรมมาแล้ว

ในบทช่วยสอนนี้ คุณจะได้สำรวจตัวอย่างเว็บแอปพลิเคชันที่สาธิตการถ่ายโอนการเรียนรู้โดยใช้ TensorFlow.js Layers API ตัวอย่างจะโหลดโมเดลที่ได้รับการฝึกล่วงหน้า จากนั้นฝึกโมเดลใหม่ในเบราว์เซอร์

โมเดลนี้ได้รับการฝึกอบรมล่วงหน้าใน Python ด้วยตัวเลข 0-4 ของ ชุดข้อมูลการจำแนกประเภทหลัก MNIST การอบรมขึ้นใหม่ (หรือถ่ายโอนการเรียนรู้) ในเบราว์เซอร์ใช้ตัวเลข 5-9 ตัวอย่างแสดงให้เห็นว่าโมเดลที่ได้รับการฝึกอบรมล่วงหน้าหลายชั้นแรกสามารถใช้เพื่อแยกคุณสมบัติจากข้อมูลใหม่ในระหว่างการเรียนรู้การถ่ายโอน ซึ่งช่วยให้สามารถฝึกอบรมข้อมูลใหม่ได้เร็วขึ้น

แอปพลิเคชันตัวอย่างสำหรับบทช่วยสอนนี้มี ให้ใช้งานออนไลน์ ดังนั้นคุณไม่จำเป็นต้องดาวน์โหลดโค้ดใดๆ หรือตั้งค่าสภาพแวดล้อมการพัฒนา หากคุณต้องการเรียกใช้โค้ดในเครื่อง ให้ทำตามขั้นตอนเสริมใน เรียกใช้ตัวอย่างในเครื่อง หากคุณไม่ต้องการตั้งค่าสภาพแวดล้อมการพัฒนา คุณสามารถข้ามไปที่ สำรวจตัวอย่าง ได้

โค้ดตัวอย่างมีอยู่ใน GitHub

(ไม่บังคับ) เรียกใช้ตัวอย่างในเครื่อง

ข้อกำหนดเบื้องต้น

หากต้องการเรียกใช้แอปตัวอย่างในเครื่อง คุณต้องติดตั้งสิ่งต่อไปนี้ในสภาพแวดล้อมการพัฒนาของคุณ:

ติดตั้งและเรียกใช้แอปตัวอย่าง

  1. โคลนหรือดาวน์โหลดพื้นที่เก็บ tfjs-examples
  2. เปลี่ยนเป็นไดเร็กทอรี mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. ติดตั้งการพึ่งพา:

    yarn
    
  4. เริ่มเซิร์ฟเวอร์การพัฒนา:

    yarn run watch
    

สำรวจตัวอย่าง

เปิดแอปตัวอย่าง (หรือหากคุณใช้งานตัวอย่างในเครื่อง ให้ไปที่ http://localhost:1234 ในเบราว์เซอร์ของคุณ)

คุณควรเห็นหน้าที่ชื่อ MNIST CNN Transfer Learning ทำตามคำแนะนำเพื่อลองใช้แอป

ต่อไปนี้เป็นบางสิ่งที่ควรลอง:

  • ทดลองใช้โหมดการฝึกต่างๆ และเปรียบเทียบการสูญเสียและความแม่นยำ
  • เลือกตัวอย่างบิตแมปที่แตกต่างกัน และตรวจสอบความน่าจะเป็นในการจัดหมวดหมู่ โปรดทราบว่าตัวเลขในแต่ละตัวอย่างบิตแมปเป็นค่าจำนวนเต็มระดับสีเทาที่แสดงถึงพิกเซลจากรูปภาพ
  • แก้ไขค่าจำนวนเต็มบิตแมปและดูว่าการเปลี่ยนแปลงส่งผลต่อความน่าจะเป็นในการจำแนกประเภทอย่างไร

สำรวจรหัส

เว็บแอปตัวอย่างจะโหลดโมเดลที่ได้รับการฝึกอบรมล่วงหน้าในชุดย่อยของชุดข้อมูล MNIST การฝึกอบรมล่วงหน้าถูกกำหนดไว้ในโปรแกรม Python: mnist_transfer_cnn.py โปรแกรม Python อยู่นอกขอบเขตสำหรับบทช่วยสอนนี้ แต่ก็คุ้มค่าที่จะดูว่าคุณต้องการดูตัวอย่าง การแปลงโมเดล หรือไม่

ไฟล์ index.js มีโค้ดการฝึกอบรมส่วนใหญ่สำหรับการสาธิต เมื่อ index.js ทำงานในเบราว์เซอร์ ฟังก์ชันการตั้งค่า setupMnistTransferCNN จะสร้างอินสแตนซ์และเริ่มต้น MnistTransferCNNPredictor ซึ่งจะสรุปกิจวัตรการฝึกอบรมใหม่และการทำนาย

วิธีการเริ่มต้น MnistTransferCNNPredictor.init จะโหลดโมเดล โหลดข้อมูลการฝึกอบรมใหม่ และสร้างข้อมูลการทดสอบ นี่คือ บรรทัด ที่โหลดโมเดล:

this.model = await loader.loadHostedPretrainedModel(urls.model);

หากคุณดูคำจำกัดความของ loader.loadHostedPretrainedModel คุณจะเห็นว่ามันส่งคืนผลลัพธ์ของการเรียกไปที่ tf.loadLayersModel นี่คือ TensorFlow.js API สำหรับการโหลดโมเดลที่ประกอบด้วยออบเจ็กต์เลเยอร์

ตรรกะการฝึกอบรมขึ้นใหม่ถูกกำหนดไว้ใน MnistTransferCNNPredictor.retrainModel หากผู้ใช้เลือก ตรึงเลเยอร์คุณลักษณะ เป็นโหมดการฝึก 7 เลเยอร์แรกของโมเดลพื้นฐานจะถูกตรึง และเฉพาะ 5 เลเยอร์สุดท้ายเท่านั้นที่จะได้รับการฝึกกับข้อมูลใหม่ หากผู้ใช้เลือก กำหนด ค่าเริ่มต้นใหม่ น้ำหนัก น้ำหนักทั้งหมดจะถูกรีเซ็ต และแอปจะฝึกโมเดลตั้งแต่เริ่มต้นอย่างมีประสิทธิภาพ

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

จากนั้นโมเดลจะถูก คอมไพล์ จากนั้นจะ ถูกฝึก เกี่ยวกับข้อมูลทดสอบโดยใช้ model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับพารามิเตอร์ model.fit() โปรดดู เอกสารประกอบ API

หลังจากได้รับการฝึกอบรมเกี่ยวกับชุดข้อมูลใหม่ (ตัวเลข 5-9) แล้ว แบบจำลองจะสามารถนำมาใช้ในการคาดการณ์ได้ เมธอด MnistTransferCNNPredictor.predict ทำสิ่งนี้โดยใช้ model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

สังเกตการใช้ tf.tidy ซึ่งช่วยป้องกันหน่วยความจำรั่ว

เรียนรู้เพิ่มเติม

บทช่วยสอนนี้ได้สำรวจแอปตัวอย่างที่ทำการถ่ายโอนการเรียนรู้ในเบราว์เซอร์โดยใช้ TensorFlow.js ดูแหล่งข้อมูลด้านล่างเพื่อเรียนรู้เพิ่มเติมเกี่ยวกับโมเดลที่ได้รับการฝึกอบรมล่วงหน้าและถ่ายทอดการเรียนรู้

TensorFlow.js

แกนเทนเซอร์โฟลว์