TensorFlow GraphDef ベースのモデルを TensorFlow.js にインポートする

TensorFlow GraphDef ベースのモデル(通常は Python API で作成されたモデル)は、以下のいずれかの形式で保存することができます。

  1. TensorFlow SavedModel
  2. 凍結モデル
  3. Tensorflowハブモジュール

上記の形式はすべて、TensorFlow.js コンバータを使用して TensorFlow.js に直接読み込める形式に変換し、推論に利用することができます。

(注意: セッションバンドル形式は TensorFlow では非推奨なので、モデルを SavedModel 形式に移行してください。)

要件

変換プロシージャには Python 環境が必要です。これは pipenvvirtualenv を使用して隔離しておくことをお勧めします。コンバータのインストールには以下のコマンドを実行します。

 pip install tensorflowjs

TensorFlow モデルを TensorFlow.js にインポートするには、2 つのステップがあります。1 番目に既存のモデルを TensorFlow.js Web 形式に変換し、2 番目にそれを TensorFlow.js に読み込みます。

ステップ 1. 既存の TensorFlow モデルを TensorFlow.js Web 形式に変換する

pip パッケージで提供されているコンバータのスクリプトを実行します。

SavedModel の使用例。

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

凍結モデルの使用例:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Tensorflow Hub モジュールの使用例:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
位置指定引数 説明
input_path SavedModel ディレクトリ、セッションバンドルディレクトリ、凍結モデルファイル、あるいは TensorFlow Hub モジュールのハンドルまたはパスのフルパス。
output_path すべての中間生成物のパス。
オプション 説明
--input_format 入力モデルのフォーマットは、SavedModel には tf_saved_model、凍結モデルには tf_frozen_model、セッションバンドルには tf_session_bundle、TensorFlow Hub モジュールには tf_hub、Keras HDF5 には Keras を使用する。
--output_node_names カンマで区切られた出力ノードの名前。
--saved_model_tags SavedModel 変換のみに適用され、MetaGraphDef のタグを読み込む。カンマで区切られた形式で表示される。デフォルトはserve
--signature_name TensorFlow Hub モジュール変換のみに適用され、シグネチャを読み込む。デフォルトはdefault。詳細は https://www.tensorflow.org/hub/common_signatures/ を参照。

ヘルプメッセージの詳細を表示するには、以下のコマンドを使用します。

tensorflowjs_converter --help

コンバータで生成したファイル

上記の変換スクリプトは、2 種類のファイルを生成します。

  • model.json(データフローグラフと重みマニフェスト)
  • group1-shard\*of\*(バイナリ重みファイルのコレクション)

例えば、以下は MobileNet v2 を変換した場合の出力です。

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Step 2: ブラウザで読み込んで実行する

  1. 下記の tfjs-converter npm パッケージをインストールします。

yarn add @tensorflow/tfjs または npm install @tensorflow/tfjs

  1. 凍結モデルクラスをインスタンス化して推論を実行します。
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

MobileNet デモをご覧ください。

loadGraphModel API は追加のLoadOptionsパラメータを受け入れるため、これを使用して認証情報やカスタムヘッダをリクエストと共に送信することができます。詳細は loadGraphModel() ドキュメントをご覧ください。

サポートする演算

現在、TensorFlow.js がサポートする TensorFlow 演算は限られています。モデルがサポートされていない演算を使用している場合、tensorflowjs_converter スクリプトは失敗し、モデル内にあるサポートされていない演算のリストを出力します。各演算ごとに 1 つずつ issue を発行して、サポートが必要な演算を報告してください。

重みだけを読み込む

重みだけを読み込む場合には、以下のコードスニペットを使用します。

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");