Visualizza su TensorFlow.org | Esegui in Google Colab | Visualizza su GitHub | Scarica quaderno |
Panoramica
Questa guida fornisce un elenco di best practice per la scrittura di codice utilizzando TensorFlow 2 (TF2), è stata scritta per gli utenti che sono passati di recente da TensorFlow 1 (TF1). Fare riferimento alla sezione di migrazione della guida per ulteriori informazioni sulla migrazione del codice da TF1 a TF2.
Impostare
Importa TensorFlow e altre dipendenze per gli esempi in questa guida.
import tensorflow as tf
import tensorflow_datasets as tfds
Raccomandazioni per TensorFlow idiomatico 2
Refactoring del tuo codice in moduli più piccoli
Una buona pratica consiste nel refactoring del codice in funzioni più piccole che vengono chiamate secondo necessità. Per ottenere le migliori prestazioni, dovresti provare a decorare i blocchi di calcolo più grandi che puoi in una tf.function
(nota che le funzioni python nidificate chiamate da una tf.function
non richiedono le loro decorazioni separate, a meno che tu non voglia usare jit_compile
differenti impostazioni per la tf.function
). A seconda del tuo caso d'uso, potrebbero trattarsi di più fasi di allenamento o persino dell'intero ciclo di allenamento. Per i casi d'uso dell'inferenza, potrebbe essere un singolo passaggio in avanti del modello.
Regola il tasso di apprendimento predefinito per alcuni tf.keras.optimizer
s
Alcuni ottimizzatori Keras hanno tassi di apprendimento diversi in TF2. Se vedi un cambiamento nel comportamento di convergenza per i tuoi modelli, controlla i tassi di apprendimento predefiniti.
Non ci sono modifiche per optimizers.SGD
, optimizers.Adam
o optimizers.RMSprop
.
I seguenti tassi di apprendimento predefiniti sono cambiati:
-
optimizers.Adagrad
da0.01
a0.001
-
optimizers.Adadelta
da1.0
a0.001
-
optimizers.Adamax
da0.002
a0.001
-
optimizers.Nadam
da0.002
a0.001
Usa i livelli tf.Module
Keras per gestire le variabili
tf.Module
s e tf.keras.layers.Layer
s offrono le comode variables
e le proprietà trainable_variables
, che raccolgono ricorsivamente tutte le variabili dipendenti. Ciò semplifica la gestione delle variabili localmente nel punto in cui vengono utilizzate.
I livelli/modelli Keras ereditano da tf.train.Checkpointable
e sono integrati con @tf.function
, che consente di effettuare il checkpoint o esportare direttamente i modelli salvati da oggetti Keras. Non è necessario utilizzare l'API Model.fit
di Keras per sfruttare queste integrazioni.
Leggi la sezione sull'apprendimento del trasferimento e la messa a punto nella guida di Keras per imparare come raccogliere un sottoinsieme di variabili rilevanti usando Keras.
Combina tf.data.Dataset
tf.function
Il pacchetto TensorFlow Datasets ( tfds
) contiene utilità per caricare set di dati predefiniti come oggetti tf.data.Dataset
. Per questo esempio, puoi caricare il set di dati MNIST usando tfds
:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Quindi preparare i dati per l'allenamento:
- Ridimensiona ogni immagine.
- Mescola l'ordine degli esempi.
- Raccogliere lotti di immagini ed etichette.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
Per mantenere l'esempio breve, ritaglia il set di dati per restituire solo 5 batch:
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-12-08 17:15:01.637157: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Usa l'iterazione Python regolare per eseguire l'iterazione sui dati di addestramento che si adattano alla memoria. In caso contrario, tf.data.Dataset
è il modo migliore per eseguire lo streaming dei dati di addestramento dal disco. I set di dati sono iterabili (non iteratori) e funzionano proprio come altri iterabili Python nell'esecuzione ansiosa. Puoi utilizzare completamente le funzionalità di precaricamento/streaming asincrono del set di dati avvolgendo il codice in tf.function
, che sostituisce l'iterazione Python con le operazioni del grafico equivalenti usando AutoGraph.
@tf.function
def train(model, dataset, optimizer):
for x, y in dataset:
with tf.GradientTape() as tape:
# training=True is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
prediction = model(x, training=True)
loss = loss_fn(prediction, y)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
Se utilizzi l'API Keras Model.fit
, non dovrai preoccuparti dell'iterazione del set di dati.
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)
Usa i cicli di allenamento Keras
Se non hai bisogno di un controllo di basso livello del tuo processo di allenamento, si consiglia di utilizzare i metodi integrati di Keras fit
, evaluate
e predict
. Questi metodi forniscono un'interfaccia uniforme per addestrare il modello indipendentemente dall'implementazione (sequenziale, funzionale o sottoclasse).
I vantaggi di questi metodi includono:
- Accettano array Numpy, generatori Python e
tf.data.Datasets
. - Applicano automaticamente la regolarizzazione e le perdite di attivazione.
- Supportano
tf.distribute
dove il codice di addestramento rimane lo stesso indipendentemente dalla configurazione hardware . - Supportano callable arbitrari come perdite e metriche.
- Supportano callback come
tf.keras.callbacks.TensorBoard
e callback personalizzate. - Sono performanti, utilizzando automaticamente i grafici TensorFlow.
Ecco un esempio di training di un modello utilizzando un Dataset
. Per i dettagli su come funziona, dai un'occhiata ai tutorial .
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5 5/5 [==============================] - 9s 7ms/step - loss: 1.5762 - accuracy: 0.4938 Epoch 2/5 2021-12-08 17:15:11.145429: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 6ms/step - loss: 0.5087 - accuracy: 0.8969 Epoch 3/5 2021-12-08 17:15:11.559374: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 2s 5ms/step - loss: 0.3348 - accuracy: 0.9469 Epoch 4/5 2021-12-08 17:15:13.860407: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 5ms/step - loss: 0.2445 - accuracy: 0.9688 Epoch 5/5 2021-12-08 17:15:14.269850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 6ms/step - loss: 0.2006 - accuracy: 0.9719 2021-12-08 17:15:14.717552: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 1s 4ms/step - loss: 1.4553 - accuracy: 0.5781 Loss 1.4552843570709229, Accuracy 0.578125 2021-12-08 17:15:15.862684: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Personalizza l'allenamento e scrivi il tuo ciclo
Se i modelli Keras funzionano per te, ma hai bisogno di maggiore flessibilità e controllo della fase di formazione o dei cicli di formazione esterni, puoi implementare le tue fasi di formazione o anche interi cicli di formazione. Consulta la guida Keras sulla personalizzazione della fit
per saperne di più.
Puoi anche implementare molte cose come tf.keras.callbacks.Callback
.
Questo metodo ha molti dei vantaggi menzionati in precedenza , ma ti dà il controllo del passaggio del treno e persino del ciclo esterno.
Ci sono tre passaggi per un ciclo di formazione standard:
- Esegui l'iterazione su un generatore Python o
tf.data.Dataset
per ottenere batch di esempi. - Usa
tf.GradientTape
per raccogliere i gradienti. - Utilizzare uno dei
tf.keras.optimizers
per applicare gli aggiornamenti di peso alle variabili del modello.
Ricorda:
- Includere sempre un argomento di
training
sul metodo dicall
di livelli e modelli di sottoclassi. - Assicurati di chiamare il modello con l'argomento di
training
impostato correttamente. - A seconda dell'utilizzo, le variabili del modello potrebbero non esistere finché il modello non viene eseguito su un batch di dati.
- Devi gestire manualmente cose come le perdite di regolarizzazione per il modello.
Non è necessario eseguire inizializzatori di variabili o aggiungere dipendenze di controllo manuali. tf.function
gestisce automaticamente le dipendenze di controllo e l'inizializzazione delle variabili durante la creazione.
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
for epoch in range(NUM_EPOCHS):
for inputs, labels in train_data:
train_step(inputs, labels)
print("Finished epoch", epoch)
2021-12-08 17:15:16.714849: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 0 2021-12-08 17:15:17.097043: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 1 2021-12-08 17:15:17.502480: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 2 2021-12-08 17:15:17.873701: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 3 Finished epoch 4 2021-12-08 17:15:18.344196: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Sfrutta tf.function
con il flusso di controllo Python
tf.function
fornisce un modo per convertire il flusso di controllo dipendente dai dati in equivalenti in modalità grafico come tf.cond
e tf.while_loop
.
Un luogo comune in cui appare il flusso di controllo dipendente dai dati è nei modelli di sequenza. tf.keras.layers.RNN
wrapping di una cella RNN, consentendo di svolgere la ricorrenza in modo statico o dinamico. Ad esempio, puoi reimplementare lo srotolamento dinamico come segue.
class DynamicRNN(tf.keras.Model):
def __init__(self, rnn_cell):
super(DynamicRNN, self).__init__(self)
self.cell = rnn_cell
@tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
def call(self, input_data):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
timesteps = tf.shape(input_data)[0]
batch_size = tf.shape(input_data)[1]
outputs = tf.TensorArray(tf.float32, timesteps)
state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
for i in tf.range(timesteps):
output, state = self.cell(input_data[i], state)
outputs = outputs.write(i, output)
return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)
my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)
Leggere la guida tf.function
per ulteriori informazioni.
Metriche e perdite di nuovo stile
Metriche e perdite sono entrambi oggetti che funzionano avidamente e in tf.function
s.
Un oggetto di perdita è richiamabile e si aspetta ( y_true
, y_pred
) come argomenti:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815
Utilizza le metriche per raccogliere e visualizzare i dati
È possibile utilizzare tf.metrics
per aggregare i dati e tf.summary
per registrare i riepiloghi e reindirizzarli a uno scrittore utilizzando un gestore di contesto. I riepiloghi vengono inviati direttamente allo scrittore, il che significa che è necessario fornire il valore del step
sul sito di chiamata.
summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
tf.summary.scalar('loss', 0.1, step=42)
Utilizza tf.metrics
per aggregare i dati prima di registrarli come riepiloghi. Le metriche sono stateful; accumulano valori e restituiscono un risultato cumulativo quando si chiama il metodo result
(come Mean.result
). Cancella i valori accumulati con Model.reset_states
.
def train(model, optimizer, dataset, log_freq=10):
avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
for images, labels in dataset:
loss = train_step(model, optimizer, images, labels)
avg_loss.update_state(loss)
if tf.equal(optimizer.iterations % log_freq, 0):
tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
avg_loss.reset_states()
def test(model, test_x, test_y, step_num):
# training=False is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
loss = loss_fn(model(test_x, training=False), test_y)
tf.summary.scalar('loss', loss, step=step_num)
train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')
with train_summary_writer.as_default():
train(model, optimizer, dataset)
with test_summary_writer.as_default():
test(model, test_x, test_y, optimizer.iterations)
Visualizza i riepiloghi generati puntando TensorBoard alla directory del registro riepilogativo:
tensorboard --logdir /tmp/summaries
Utilizza l'API tf.summary
per scrivere dati di riepilogo per la visualizzazione in TensorBoard. Per maggiori informazioni, leggi la guida tf.summary
.
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
# Update the metrics
loss_metric.update_state(total_loss)
accuracy_metric.update_state(labels, predictions)
for epoch in range(NUM_EPOCHS):
# Reset the metrics
loss_metric.reset_states()
accuracy_metric.reset_states()
for inputs, labels in train_data:
train_step(inputs, labels)
# Get the metric results
mean_loss=loss_metric.result()
mean_accuracy = accuracy_metric.result()
print('Epoch: ', epoch)
print(' loss: {:.3f}'.format(mean_loss))
print(' accuracy: {:.3f}'.format(mean_accuracy))
2021-12-08 17:15:19.339736: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 0 loss: 0.142 accuracy: 0.991 2021-12-08 17:15:19.781743: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 1 loss: 0.125 accuracy: 0.997 2021-12-08 17:15:20.219033: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 2 loss: 0.110 accuracy: 0.997 2021-12-08 17:15:20.598085: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 3 loss: 0.099 accuracy: 0.997 Epoch: 4 loss: 0.085 accuracy: 1.000 2021-12-08 17:15:20.981787: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Nomi delle metriche Keras
I modelli Keras sono coerenti nella gestione dei nomi delle metriche. Quando passi una stringa nell'elenco delle metriche, quella stringa esatta viene utilizzata come name
della metrica . Questi nomi sono visibili nell'oggetto cronologia restituito da model.fit
e nei log passati a keras.callbacks
. è impostato sulla stringa passata nell'elenco delle metriche.
model.compile(
optimizer = tf.keras.optimizers.Adam(0.001),
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0963 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969 2021-12-08 17:15:21.942940: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])
Debug
Usa l'esecuzione desiderosa per eseguire il codice passo dopo passo per ispezionare forme, tipi di dati e valori. Alcune API, come tf.function
, tf.keras
, ecc. sono progettate per utilizzare l'esecuzione di Graph, per prestazioni e portabilità. Durante il debug, utilizzare tf.config.run_functions_eagerly(True)
per utilizzare l'esecuzione desiderosa all'interno di questo codice.
Per esempio:
@tf.function
def f(x):
if x > 0:
import pdb
pdb.set_trace()
x = x + 1
return x
tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
6 @tf.function
7 def f(x):
8 if x > 0:
9 import pdb
10 pdb.set_trace()
11 -> x = x + 1
12 return x
13
14 tf.config.run_functions_eagerly(True)
15 f(tf.constant(1))
[EOF]
Funziona anche all'interno dei modelli Keras e di altre API che supportano l'esecuzione ansiosa:
class CustomModel(tf.keras.models.Model):
@tf.function
def call(self, input_data):
if tf.reduce_mean(input_data) > 0:
return input_data
else:
import pdb
pdb.set_trace()
return input_data // 2
tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
10 if tf.reduce_mean(input_data) > 0:
11 return input_data
12 else:
13 import pdb
14 pdb.set_trace()
15 -> return input_data // 2
16
17
18 tf.config.run_functions_eagerly(True)
19 model = CustomModel()
20 model(tf.constant([-2, -4]))
Appunti:
tf.keras.Model
metodi comefit
,evaluate
epredict
vengono eseguiti come grafici contf.function
sotto il cofano.Quando si utilizza
tf.keras.Model.compile
, impostarerun_eagerly = True
per disabilitare la logica delModel
dall'essere racchiusa in unatf.function
.Utilizzare
tf.data.experimental.enable_debug_mode
per abilitare la modalità di debug pertf.data
. Leggi i documenti API per maggiori dettagli.
Non tenere tf.Tensors
nei tuoi oggetti
Questi oggetti tensore potrebbero essere creati in una tf.function
o nel contesto desideroso e questi tensori si comportano in modo diverso. Utilizzare sempre tf.Tensor
s solo per valori intermedi.
Per tenere traccia dello stato, usa tf.Variable
s poiché sono sempre utilizzabili da entrambi i contesti. Leggi la guida tf.Variable
per saperne di più.
Risorse e ulteriori letture
Leggi le guide e i tutorial di TF2 per saperne di più su come utilizzare TF2.
Se in precedenza hai utilizzato TF1.x, si consiglia vivamente di migrare il codice a TF2. Leggi le guide alla migrazione per saperne di più.