Sử dụng một mô hình được đào tạo trước

Trong hướng dẫn này, bạn sẽ khám phá một ứng dụng web mẫu thể hiện việc học chuyển giao bằng cách sử dụng API lớp TensorFlow.js. Ví dụ này tải một mô hình được đào tạo trước và sau đó đào tạo lại mô hình đó trong trình duyệt.

Mô hình đã được đào tạo trước bằng Python trên các chữ số 0-4 của tập dữ liệu phân loại chữ số MNIST . Việc đào tạo lại (hoặc học chuyển tiếp) trong trình duyệt sử dụng các chữ số 5-9. Ví dụ cho thấy rằng một số lớp đầu tiên của mô hình được đào tạo trước có thể được sử dụng để trích xuất các tính năng từ dữ liệu mới trong quá trình học chuyển, do đó cho phép đào tạo dữ liệu mới nhanh hơn.

Ứng dụng ví dụ cho hướng dẫn này có sẵn trực tuyến , do đó bạn không cần tải xuống bất kỳ mã nào hoặc thiết lập môi trường phát triển. Nếu bạn muốn chạy mã cục bộ, hãy hoàn thành các bước tùy chọn trong Chạy ví dụ cục bộ . Nếu bạn không muốn thiết lập môi trường phát triển, bạn có thể bỏ qua phần Khám phá ví dụ .

Mã ví dụ có sẵn trên GitHub .

(Tùy chọn) Chạy ví dụ cục bộ

Điều kiện tiên quyết

Để chạy ứng dụng mẫu cục bộ, bạn cần cài đặt những thứ sau trong môi trường phát triển của mình:

Cài đặt và chạy ứng dụng mẫu

  1. Sao chép hoặc tải xuống kho lưu trữ tfjs-examples .
  2. Thay đổi vào thư mục mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Cài đặt phụ thuộc:

    yarn
    
  4. Khởi động máy chủ phát triển:

    yarn run watch
    

Khám phá ví dụ

Mở ứng dụng ví dụ . (Hoặc, nếu bạn đang chạy ví dụ cục bộ, hãy truy cập http://localhost:1234 trong trình duyệt của bạn.)

Bạn sẽ thấy một trang có tiêu đề MNIST CNN Transfer Learning . Làm theo hướng dẫn để dùng thử ứng dụng.

Dưới đây là một số điều cần thử:

  • Thử nghiệm với các chế độ luyện tập khác nhau và so sánh độ mất và độ chính xác.
  • Chọn các ví dụ bitmap khác nhau và kiểm tra xác suất phân loại. Lưu ý rằng các số trong mỗi ví dụ bitmap là các giá trị số nguyên thang độ xám biểu thị các pixel từ một hình ảnh.
  • Chỉnh sửa các giá trị số nguyên bitmap và xem những thay đổi này ảnh hưởng như thế nào đến xác suất phân loại.

Khám phá mã

Ứng dụng web mẫu tải một mô hình đã được đào tạo trước trên một tập hợp con của tập dữ liệu MNIST. Việc đào tạo trước được xác định trong chương trình Python: mnist_transfer_cnn.py . Chương trình Python nằm ngoài phạm vi của hướng dẫn này, nhưng bạn nên xem xét nếu muốn xem ví dụ về chuyển đổi mô hình .

Tệp index.js chứa hầu hết mã đào tạo cho bản demo. Khi index.js chạy trong trình duyệt, một hàm thiết lập setupMnistTransferCNN sẽ khởi tạo và khởi tạo MnistTransferCNNPredictor , hàm này gói gọn các quy trình đào tạo lại và dự đoán.

Phương thức khởi tạo, MnistTransferCNNPredictor.init , tải mô hình, tải dữ liệu đào tạo lại và tạo dữ liệu thử nghiệm. Đây là dòng tải mô hình:

this.model = await loader.loadHostedPretrainedModel(urls.model);

Nếu bạn nhìn vào định nghĩa của loader.loadHostedPretrainedModel , bạn sẽ thấy rằng nó trả về kết quả của lệnh gọi tới tf.loadLayersModel . Đây là API TensorFlow.js để tải mô hình bao gồm các đối tượng Lớp.

Logic đào tạo lại được xác định trong MnistTransferCNNPredictor.retrainModel . Nếu người dùng đã chọn Đóng băng các lớp tính năng làm chế độ đào tạo thì 7 lớp đầu tiên của mô hình cơ sở sẽ bị đóng băng và chỉ 5 lớp cuối cùng được đào tạo trên dữ liệu mới. Nếu người dùng đã chọn Khởi tạo lại trọng số thì tất cả trọng số sẽ được đặt lại và ứng dụng sẽ huấn luyện mô hình từ đầu một cách hiệu quả.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

Sau đó, mô hình sẽ được biên dịchhuấn luyện dựa trên dữ liệu thử nghiệm bằng cách sử dụng model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

Để tìm hiểu thêm về các tham số model.fit() , hãy xem tài liệu API .

Sau khi được huấn luyện về tập dữ liệu mới (chữ số 5-9), mô hình có thể được sử dụng để đưa ra dự đoán. Phương thức MnistTransferCNNPredictor.predict thực hiện điều này bằng cách sử dụng model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Lưu ý việc sử dụng tf.tidy , giúp ngăn ngừa rò rỉ bộ nhớ.

Tìm hiểu thêm

Hướng dẫn này đã khám phá một ứng dụng mẫu thực hiện việc học chuyển giao trong trình duyệt bằng TensorFlow.js. Hãy xem các tài nguyên bên dưới để tìm hiểu thêm về các mô hình được đào tạo trước và học tập chuyển giao.

TensorFlow.js

Lõi TenorFlow