Migrazione da TF1 a TF2 con TensorFlow Hub

Questa pagina spiega come continuare a utilizzare TensorFlow Hub durante la migrazione del codice TensorFlow da TensorFlow 1 a TensorFlow 2. Integra la guida generale alla migrazione di TensorFlow.

Per TF2, TF Hub si è allontanato dall'API legacy hub.Module per creare un tf.compat.v1.Graph come fa tf.contrib.v1.layers . Invece, ora c'è un hub.KerasLayer da utilizzare insieme ad altri livelli Keras per costruire un tf.keras.Model (tipicamente nel nuovo ambiente di esecuzione entusiasta di TF2) e il suo metodo hub.load() sottostante per il codice TensorFlow di basso livello.

L'API hub.Module rimane disponibile nella libreria tensorflow_hub per l'uso in TF1 e nella modalità di compatibilità TF1 di TF2. Può caricare solo modelli nel formato TF1 Hub .

La nuova API di hub.load() e hub.KerasLayer funziona per TensorFlow 1.15 (in modalità entusiasta e grafico) e in TensorFlow 2. Questa nuova API può caricare le nuove risorse TF2 SavedModel e, con le restrizioni stabilite nel modello guida alla compatibilità , i modelli legacy in formato TF1 Hub.

In generale, si consiglia di utilizzare la nuova API ove possibile.

Riepilogo della nuova API

hub.load() è la nuova funzione di basso livello per caricare un SavedModel da TensorFlow Hub (o servizi compatibili). Avvolge tf.saved_model.load() di TF2; La Guida SavedModel di TensorFlow descrive cosa puoi fare con il risultato.

m = hub.load(handle)
outputs = m(inputs)

La classe hub.KerasLayer chiama hub.load() e adatta il risultato per l'uso in Keras insieme ad altri livelli Keras. (Potrebbe anche essere un comodo wrapper per i SavedModels caricati utilizzati in altri modi.)

model = tf.keras.Sequential([
    hub.KerasLayer(handle),
    ...])

Molti tutorial mostrano queste API in azione. Ecco alcuni esempi:

Utilizzo della nuova API nella formazione Estimator

Se utilizzi un TF2 SavedModel in un Estimator per l'addestramento con server di parametri (o altrimenti in una sessione TF1 con variabili posizionate su dispositivi remoti), devi impostare experimental.share_cluster_devices_in_session nel ConfigProto di tf.Session, altrimenti riceverai un errore come "Il dispositivo assegnato '/job:ps/replica:0/task:0/device:CPU:0' non corrisponde ad alcun dispositivo."

L'opzione necessaria può essere impostata come

session_config = tf.compat.v1.ConfigProto()
session_config.experimental.share_cluster_devices_in_session = True
run_config = tf.estimator.RunConfig(..., session_config=session_config)
estimator = tf.estimator.Estimator(..., config=run_config)

A partire da TF2.2, questa opzione non è più sperimentale e il pezzo .experimental può essere eliminato.

Caricamento di modelli legacy nel formato TF1 Hub

Può succedere che un nuovo TF2 SavedModel non sia ancora disponibile per il tuo caso d'uso e sia necessario caricare un modello legacy nel formato TF1 Hub. A partire dalla versione 0.7 tensorflow_hub , puoi utilizzare il modello legacy nel formato TF1 Hub insieme a hub.KerasLayer come mostrato di seguito:

m = hub.KerasLayer(handle)
tensor_out = m(tensor_in)

Inoltre KerasLayer offre la possibilità di specificare tags , signature , output_key e signature_outputs_as_dict per utilizzi più specifici di modelli legacy nel formato TF1 Hub e SavedModels legacy.

Per ulteriori informazioni sulla compatibilità del formato TF1 Hub consultare la guida alla compatibilità dei modelli .

Utilizzo di API di livello inferiore

I modelli in formato TF1 Hub legacy possono essere caricati tramite tf.saved_model.load . Invece di

# DEPRECATED: TensorFlow 1
m = hub.Module(handle, tags={"foo", "bar"})
tensors_out_dict = m(dict(x1=..., x2=...), signature="sig", as_dict=True)

si consiglia di utilizzare:

# TensorFlow 2
m = hub.load(path, tags={"foo", "bar"})
tensors_out_dict = m.signatures["sig"](x1=..., x2=...)

In questi esempi m.signatures è un comando di funzioni concrete di TensorFlow codificate da nomi di firma. La chiamata a tale funzione calcola tutti i suoi output, anche se inutilizzati. (Ciò è diverso dalla valutazione pigra della modalità grafico di TF1.)