I modelli basati su TensorFlow GraphDef (in genere creati tramite l'API Python) possono essere salvati in uno dei seguenti formati:
- Modello salvato di TensorFlow
- Modello congelato
- Modulo Tensorflow Hub
Tutti i formati di cui sopra possono essere convertiti dal convertitore TensorFlow.js in un formato che può essere caricato direttamente in TensorFlow.js per l'inferenza.
(Nota: TensorFlow ha deprecato il formato del bundle di sessione. Migra i tuoi modelli al formato SavedModel.)
Requisiti
La procedura di conversione richiede un ambiente Python; potresti voler mantenerne uno isolato usando pipenv o virtualenv .
Per installare il convertitore, eseguire il comando seguente:
pip install tensorflowjs
L'importazione di un modello TensorFlow in TensorFlow.js è un processo in due passaggi. Innanzitutto, converti un modello esistente nel formato web TensorFlow.js, quindi caricalo in TensorFlow.js.
Passaggio 1. Converti un modello TensorFlow esistente nel formato web TensorFlow.js
Esegui lo script del convertitore fornito dal pacchetto pip:
Esempio di modello salvato:
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
Esempio di modello congelato:
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
Esempio di modulo Tensorflow Hub:
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
Argomenti posizionali | Descrizione |
---|---|
input_path | Percorso completo della directory del modello salvato, della directory del bundle di sessione, del file del modello congelato o dell'handle o del percorso del modulo TensorFlow Hub. |
output_path | Percorso per tutti gli artefatti di output. |
Opzioni | Descrizione |
---|---|
--input_format | Il formato del modello di input. Utilizza tf_saved_model per SavedModel, tf_frozen_model per il modello congelato, tf_session_bundle per il bundle di sessione, tf_hub per il modulo TensorFlow Hub e keras per Keras HDF5. |
--output_node_names | I nomi dei nodi di output, separati da virgole. |
--saved_model_tags | Applicabile solo alla conversione SavedModel. Tag del MetaGraphDef da caricare, in formato separato da virgole. Per impostazione predefinita viene serve . |
--signature_name | Applicabile solo alla conversione del modulo TensorFlow Hub, firma da caricare. Impostazioni predefinite per default . Vedere https://www.tensorflow.org/hub/common_signatures/ |
Utilizzare il seguente comando per ottenere un messaggio di aiuto dettagliato:
tensorflowjs_converter --help
File generati dal convertitore
Lo script di conversione sopra produce due tipi di file:
-
model.json
: il grafico del flusso di dati e il manifest del peso -
group1-shard\*of\*
: una raccolta di file di peso binari
Ad esempio, ecco l'output della conversione di MobileNet v2:
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
Passaggio 2: caricamento ed esecuzione nel browser
- Installa il pacchetto tfjs-converter npm:
yarn add @tensorflow/tfjs
o npm install @tensorflow/tfjs
- Crea un'istanza della classe FrozenModel ed esegui l'inferenza.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
Guarda la demo di MobileNet .
L'API loadGraphModel
accetta un parametro LoadOptions
aggiuntivo, che può essere utilizzato per inviare credenziali o intestazioni personalizzate insieme alla richiesta. Per i dettagli, consultare la documentazione loadGraphModel() .
Operazioni supportate
Attualmente TensorFlow.js supporta un set limitato di operazioni TensorFlow. Se il tuo modello utilizza un'operazione non supportata, lo script tensorflowjs_converter
fallirà e stamperà un elenco delle operazioni non supportate nel tuo modello. Invia un problema per ciascuna operazione per farci sapere per quali operazioni hai bisogno di supporto.
Caricamento solo dei pesi
Se preferisci caricare solo i pesi, puoi utilizzare il seguente snippet di codice:
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");
// Use `weightMap` ...