事前トレーニング済みのモデルを使用する

このチュートリアルでは、TensorFlow.js レイヤー API を使用した転移学習を示すサンプル Web アプリケーションを検討します。この例では、事前トレーニングされたモデルをロードし、ブラウザーでモデルを再トレーニングします。

モデルは、MNIST 桁分類データセットの桁 0 ~ 4 で Python で事前トレーニングされています。ブラウザでの再トレーニング (または転移学習) では、5 ~ 9 の数字が使用されます。この例は、事前トレーニングされたモデルの最初のいくつかの層を使用して転移学習中に新しいデータから特徴を抽出できるため、新しいデータでのトレーニングを高速化できることを示しています。

このチュートリアルのサンプル アプリケーションはオンラインで入手できるため、コードをダウンロードしたり、開発環境をセットアップしたりする必要はありません。コードをローカルで実行する場合は、 「サンプルをローカルで実行する」のオプションの手順を実行します。開発環境をセットアップしたくない場合は、 「サンプルを調べる」に進んでください。

サンプルコードはGitHubで入手できます。

(オプション) サンプルをローカルで実行する

前提条件

サンプル アプリをローカルで実行するには、開発環境に次のものをインストールする必要があります。

サンプルアプリをインストールして実行する

  1. tfjs-examplesリポジトリのクローンを作成するか、ダウンロードします。
  2. mnist-transfer-cnnディレクトリに移動します。

    cd tfjs-examples/mnist-transfer-cnn
    
  3. 依存関係をインストールします。

    yarn
    
  4. 開発サーバーを起動します。

    yarn run watch
    

例を見てみる

サンプルアプリを開きます。 (または、サンプルをローカルで実行している場合は、ブラウザでhttp://localhost:1234に移動します。)

「 MNIST CNN Transfer Learning 」というタイトルのページが表示されるはずです。指示に従ってアプリを試してください。

以下のことを試してみてください。

  • さまざまなトレーニング モードを試して、損失と精度を比較してください。
  • さまざまなビットマップの例を選択し、分類確率を検査します。各ビットマップの例の数字は、画像のピクセルを表すグレースケールの整数値であることに注意してください。
  • ビットマップ整数値を編集し、その変更が分類確率にどのような影響を与えるかを確認します。

コードを調べる

サンプル Web アプリは、MNIST データセットのサブセットで事前トレーニングされたモデルを読み込みます。事前トレーニングは Python プログラムmnist_transfer_cnn.pyで定義されます。 Python プログラムはこのチュートリアルの範囲外ですが、モデル変換の例を確認したい場合は参照する価値があります。

index.jsファイルには、デモ用のトレーニング コードのほとんどが含まれています。 index.jsブラウザーで実行されると、セットアップ関数setupMnistTransferCNNが、再トレーニングおよび予測ルーチンをカプセル化するMnistTransferCNNPredictorをインスタンス化して初期化します。

初期化メソッドMnistTransferCNNPredictor.initは、モデルをロードし、再トレーニング データをロードし、テスト データを作成します。モデルをロードするは次のとおりです。

this.model = await loader.loadHostedPretrainedModel(urls.model);

loader.loadHostedPretrainedModelの定義を見ると、 tf.loadLayersModelへの呼び出しの結果を返すことがわかります。これは、Layer オブジェクトで構成されるモデルをロードするための TensorFlow.js API です。

再トレーニング ロジックはMnistTransferCNNPredictor.retrainModelで定義されます。ユーザーがトレーニング モードとして[フィーチャ レイヤーのフリーズ]を選択した場合、ベース モデルの最初の 7 レイヤーはフリーズされ、最後の 5 レイヤーのみが新しいデータでトレーニングされます。ユーザーが重みの再初期化を選択した場合、すべての重みがリセットされ、アプリは効果的にモデルを最初からトレーニングします。

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

次に、モデルがコンパイルされmodel.fit()を使用してテスト データでトレーニングされます

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

model.fit()パラメータの詳細については、 API ドキュメントを参照してください。

新しいデータセット (数字 5 ~ 9) でトレーニングされた後、モデルを使用して予測を行うことができます。 MnistTransferCNNPredictor.predictメソッドはmodel.predict()を使用してこれを行います。

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

メモリリークの防止に役立つtf.tidyの使用に注意してください。

もっと詳しく知る

このチュートリアルでは、TensorFlow.js を使用してブラウザーで転移学習を実行するサンプル アプリを検討しました。事前トレーニングされたモデルと転移学習について詳しくは、以下のリソースをご覧ください。

TensorFlow.js

TensorFlow コア