Web ワーカーを使用してモデルをトレーニングする,Web ワーカーを使用してモデルをトレーニングする

このチュートリアルでは、 Web ワーカーを使用してリカレント ニューラル ネットワーク(RNN) をトレーニングして整数の加算を行うサンプル Web アプリケーションを検討します。サンプル アプリでは、加算演算子を明示的に定義していません。代わりに、サンプルの合計を使用して RNN をトレーニングします。

もちろん、これは 2 つの整数を加算する最も効率的な方法ではありません。ただし、このチュートリアルでは、Web ML における重要なテクニック、つまり UI ロジックを処理するメインスレッドをブロックせずに長時間実行される計算を実行する方法を示しています。

このチュートリアルのサンプル アプリケーションはオンラインで入手できるため、コードをダウンロードしたり、開発環境をセットアップしたりする必要はありません。コードをローカルで実行する場合は、 「サンプルをローカルで実行する」のオプションの手順を実行します。開発環境をセットアップしたくない場合は、 「サンプルを調べる」に進んでください。

サンプルコードはGitHubで入手できます。

(オプション) サンプルをローカルで実行する

前提条件

サンプル アプリをローカルで実行するには、開発環境に次のものをインストールする必要があります。

サンプルアプリをインストールして実行する

  1. tfjs-examplesリポジトリのクローンを作成するか、ダウンロードします。
  2. addition-rnn-webworkerディレクトリに移動します。

    cd tfjs-examples/addition-rnn-webworker
    
  3. 依存関係をインストールします。

    yarn
    
  4. 開発サーバーを起動します。

    yarn run watch
    

例を見てみる

サンプルアプリを開きます。 (または、サンプルをローカルで実行している場合は、ブラウザでhttp://localhost:1234に移動します。)

TensorFlow.js: Addition RNNというタイトルのページが表示されるはずです。指示に従ってアプリを試してください。

Web フォームを使用すると、モデルのトレーニングに使用される次のようなパラメーターの一部を更新できます。

  • 桁数: 追加される項の最大桁数。
  • トレーニング サイズ: 生成するトレーニング サンプルの数。
  • RNN タイプ: SimpleRNNGRU 、またはLSTMのいずれか。
  • RNN Hidden Layer Size : 出力空間の次元数 (正の整数である必要があります)。
  • バッチ サイズ: 勾配更新ごとのサンプル数。
  • Train Iterations : model.fit()を呼び出してモデルをトレーニングする回数
  • # of test Example : 生成するサンプル文字列の数 (例: 27+41 )。

さまざまなパラメーターを使用してモデルをトレーニングしてみて、さまざまな数値セットの予測の精度を向上できるかどうかを確認してください。また、モデルの適合時間がさまざまなパラメーターによってどのように影響されるかにも注目してください。

コードを調べる

サンプル アプリでは、RNN をトレーニングするために構成できるパラメーターのいくつかを示します。また、Web ワーカーを使用してメインスレッドからモデルをトレーニングする方法も示します。 Web ワーカーは、計算量の多いトレーニング タスクをバックグラウンド スレッドで実行できるため、Web ML で重要です。これにより、メイン スレッドでユーザーに影響を与える可能性のあるパフォーマンスの問題を回避できます。メイン スレッドとワーカー スレッドは、メッセージ イベントを通じて相互に通信します。

Web ワーカーの詳細については、 「Web ワーカー API 」および「Web ワーカーの使用」を参照してください。

サンプルアプリのメインモジュールは、 index.jsです。 index.jsスクリプトは、 worker.jsモジュールを実行するWeb ワーカーを作成します

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.jsは主に、フォームの送信を処理し、フォーム データを処理し、ワーカーにフォーム データを渡し、ワーカーがモデルをトレーニングして結果を返すのを待ち、結果をページに表示する単一の関数runAdditionRNNDemoで構成されています。 。

フォーム データをワーカーに送信するために、スクリプトはワーカー上でpostMessage呼び出します

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});

ワーカーはこのメッセージをリッスンし、データを準備してトレーニングを開始する関数にフォーム データを渡します。

self.addEventListener('message', async (e) => {
  const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
})

トレーニング中に、ワーカーは 2 つの異なるメッセージ タイプを送信できます。1つはisPredicttrueに設定されています。

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

もう 1 つはisPredictfalseに設定されています。

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

UI スレッド ( index.js ) はメッセージ イベントを処理するときに、 isPredictフラグをチェックしてワーカーから返されたデータの形式を決定します。 isPredictが true の場合、データは予測を表す必要があり、スクリプトはtfjs-visを使用してページを更新しますisPredictが false の場合、スクリプトはデータが例を表していると想定してコード ブロックを実行します。データを HTML でラップし、その HTML をページに挿入します。

次は何ですか

このチュートリアルでは、長時間実行されるトレーニング プロセスによる UI スレッドのブロックを回避するために Web ワーカーを使用する例を示しました。バックグラウンド スレッドで負荷の高い計算を実行する利点の詳細については、 「Web ワーカーを使用してブラウザのメイン スレッドから JavaScript を実行する」を参照してください。

TensorFlow.js モデルのトレーニングの詳細については、 「モデルのトレーニング」を参照してください。