Модели на основе TensorFlow GraphDef (обычно создаваемые с помощью API Python) можно сохранить в одном из следующих форматов:
- Сохраненная модель TensorFlow
- Замороженная модель
- Модуль Tensorflow Hub
Все вышеперечисленные форматы можно преобразовать с помощью конвертера TensorFlow.js в формат, который можно загрузить непосредственно в TensorFlow.js для вывода.
(Примечание. В TensorFlow формат пакета сеансов устарел. Перенесите свои модели в формат SavedModel.)
Требования
Для процедуры преобразования требуется среда Python; возможно, вы захотите сохранить его изолированным, используя Pipenv или Virtualenv .
Чтобы установить конвертер, выполните следующую команду:
pip install tensorflowjs
Импорт модели TensorFlow в TensorFlow.js представляет собой двухэтапный процесс. Сначала преобразуйте существующую модель в веб-формат TensorFlow.js, а затем загрузите ее в TensorFlow.js.
Шаг 1. Преобразование существующей модели TensorFlow в веб-формат TensorFlow.js.
Запустите скрипт конвертера, предоставленный пакетом pip:
Пример сохраненной модели:
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
Пример замороженной модели:
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
Пример модуля Tensorflow Hub:
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
Позиционные аргументы | Описание |
---|---|
input_path | Полный путь к каталогу сохраненной модели, каталогу пакета сеанса, файлу замороженной модели или дескриптору или пути модуля TensorFlow Hub. |
output_path | Путь ко всем выходным артефактам. |
Параметры | Описание |
---|---|
--input_format | Формат входной модели. Используйте tf_saved_model для SavedModel, tf_frozen_model для замороженной модели, tf_session_bundle для пакета сеанса, tf_hub для модуля TensorFlow Hub и keras для Keras HDF5. |
--output_node_names | Имена выходных узлов, разделенные запятыми. |
--saved_model_tags | Применимо только к преобразованию SavedModel. Теги MetaGraphDef для загрузки в формате, разделенном запятыми. По умолчанию для serve . |
--signature_name | Применимо только к преобразованию модуля TensorFlow Hub, подписи для загрузки. По умолчанию по default . См. https://www.tensorflow.org/hub/common_signatures/ . |
Используйте следующую команду, чтобы получить подробное справочное сообщение:
tensorflowjs_converter --help
Файлы, созданные конвертером
Приведенный выше сценарий преобразования создает файлы двух типов:
-
model.json
: график потока данных и манифест веса. -
group1-shard\*of\*
: Коллекция файлов двоичных весов.
Например, вот результат преобразования MobileNet v2:
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
Шаг 2. Загрузка и запуск в браузере.
- Установите npm-пакет tfjs-converter:
yarn add @tensorflow/tfjs
или npm install @tensorflow/tfjs
- Создайте экземпляр класса FrozenModel и запустите вывод.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
Посмотрите демо-версию MobileNet .
API loadGraphModel
принимает дополнительный параметр LoadOptions
, который можно использовать для отправки учетных данных или пользовательских заголовков вместе с запросом. Подробности смотрите в документации loadGraphModel() .
Поддерживаемые операции
В настоящее время TensorFlow.js поддерживает ограниченный набор операций TensorFlow. Если ваша модель использует неподдерживаемую операцию, сценарий tensorflowjs_converter
завершится ошибкой и распечатает список неподдерживаемых операций в вашей модели. Пожалуйста, сообщите о проблеме для каждой операции, чтобы сообщить нам, для каких операций вам нужна поддержка.
Загрузка только гирь
Если вы предпочитаете загружать только веса, вы можете использовать следующий фрагмент кода:
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");
// Use `weightMap` ...