Visualizza su TensorFlow.org | Esegui in Google Colab | Visualizza l'origine su GitHub | Scarica quaderno |
In TensorFlow 2, l'esecuzione desiderosa è attivata per impostazione predefinita. L'interfaccia utente è intuitiva e flessibile (l'esecuzione di operazioni una tantum è molto più semplice e veloce), ma ciò può andare a scapito delle prestazioni e della distribuzione.
Puoi usare tf.function
per creare grafici dai tuoi programmi. È uno strumento di trasformazione che crea grafici del flusso di dati indipendenti da Python dal codice Python. Questo ti aiuterà a creare modelli performanti e portatili ed è necessario utilizzare SavedModel
.
Questa guida ti aiuterà a concettualizzare come funziona tf.function
sotto il cofano, in modo da poterlo utilizzare in modo efficace.
I principali takeaway e raccomandazioni sono:
- Esegui il debug in modalità desiderosa, quindi decora con
@tf.function
. - Non fare affidamento sugli effetti collaterali di Python come la mutazione di oggetti o le aggiunte di elenchi.
-
tf.function
funziona meglio con TensorFlow ops; Le chiamate NumPy e Python vengono convertite in costanti.
Impostare
import tensorflow as tf
Definisci una funzione di supporto per dimostrare i tipi di errori che potresti riscontrare:
import traceback
import contextlib
# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
try:
yield
except error_class as e:
print('Caught expected exception \n {}:'.format(error_class))
traceback.print_exc(limit=2)
except Exception as e:
raise e
else:
raise Exception('Expected {} to be raised but no error was raised!'.format(
error_class))
Nozioni di base
Utilizzo
Una Function
che definisci (ad esempio applicando il decoratore di @tf.function
) è proprio come un'operazione di base di TensorFlow: puoi eseguirla avidamente; puoi calcolare i gradienti; e così via.
@tf.function # The decorator converts `add` into a `Function`.
def add(a, b):
return a + b
add(tf.ones([2, 2]), tf.ones([2, 2])) # [[2., 2.], [2., 2.]]
<tf.Tensor: shape=(2, 2), dtype=float32, numpy= array([[2., 2.], [2., 2.]], dtype=float32)>
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
result = add(v, 1.0)
tape.gradient(result, v)
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
È possibile utilizzare Function
s all'interno di altre Function
s.
@tf.function
def dense_layer(x, w, b):
return add(tf.matmul(x, w), b)
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))
<tf.Tensor: shape=(3, 2), dtype=float32, numpy= array([[3., 3.], [3., 3.], [3., 3.]], dtype=float32)>
Function
s può essere più veloce del codice desideroso, specialmente per i grafici con molte piccole operazioni. Ma per i grafici con poche operazioni costose (come le convoluzioni), potresti non vedere molto accelerazione.
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)
@tf.function
def conv_fn(image):
return conv_layer(image)
image = tf.zeros([1, 200, 200, 100])
# Warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")
Eager conv: 0.006058974999177735 Function conv: 0.005791576000774512 Note how there's not much difference in performance for convolutions
Tracciamento
Questa sezione illustra come Function
la funzione nascosta, inclusi i dettagli di implementazione che potrebbero cambiare in futuro . Tuttavia, una volta capito perché e quando si verifica il tracciamento, è molto più semplice utilizzare tf.function
in modo efficace!
Che cos'è il "tracciamento"?
Una Function
esegue il programma in un grafico TensorFlow . Tuttavia, un tf.Graph
non può rappresentare tutte le cose che scriveresti in un programma TensorFlow desideroso. Ad esempio, Python supporta il polimorfismo, ma tf.Graph
richiede che i suoi input abbiano un tipo di dati e una dimensione specificati. Oppure puoi eseguire attività secondarie come leggere argomenti della riga di comando, generare un errore o lavorare con un oggetto Python più complesso; nessuna di queste cose può essere eseguita in un tf.Graph
.
Function
colma questa lacuna separando il codice in due fasi:
1) Nella prima fase, denominata " tracing ", Function
crea un nuovo tf.Graph
. Il codice Python viene eseguito normalmente, ma tutte le operazioni di TensorFlow (come l'aggiunta di due Tensor) sono posticipate : vengono acquisite da tf.Graph
e non vengono eseguite.
2) Nella seconda fase viene eseguito un tf.Graph
che contiene tutto ciò che è stato differito nella prima fase. Questa fase è molto più veloce della fase di tracciamento.
A seconda dei suoi input, Function
non eseguirà sempre la prima fase quando viene chiamata. Vedere "Regole di tracciamento" di seguito per avere un'idea migliore di come si ottiene tale determinazione. Saltare la prima fase ed eseguire solo la seconda fase è ciò che offre le elevate prestazioni di TensorFlow.
Quando Function
decide di tracciare, la fase di traccia è immediatamente seguita dalla seconda fase, quindi chiamando Function
crea ed esegue sia tf.Graph
. Successivamente vedrai come puoi eseguire solo la fase di tracciamento con get_concrete_function
.
Quando si passano argomenti di tipi diversi in un Function
, vengono eseguite entrambe le fasi:
@tf.function
def double(a):
print("Tracing with", a)
return a + a
print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()
Tracing with Tensor("a:0", shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32) Tracing with Tensor("a:0", shape=(), dtype=float32) tf.Tensor(2.2, shape=(), dtype=float32) Tracing with Tensor("a:0", shape=(), dtype=string) tf.Tensor(b'aa', shape=(), dtype=string)
Si noti che se si chiama ripetutamente una Function
con lo stesso tipo di argomento, TensorFlow salterà la fase di traccia e riutilizzerà un grafico tracciato in precedenza, poiché il grafico generato sarebbe identico.
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))
tf.Tensor(b'bb', shape=(), dtype=string)
Puoi usare pretty_printed_concrete_signatures()
per vedere tutte le tracce disponibili:
print(double.pretty_printed_concrete_signatures())
double(a) Args: a: int32 Tensor, shape=() Returns: int32 Tensor, shape=() double(a) Args: a: float32 Tensor, shape=() Returns: float32 Tensor, shape=() double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
Finora, hai visto che tf.function
crea un livello di invio dinamico memorizzato nella cache sulla logica di tracciamento del grafico di TensorFlow. Per essere più precisi sulla terminologia:
- Un
tf.Graph
è la rappresentazione grezza, indipendente dal linguaggio e portatile di un calcolo TensorFlow. - Una
ConcreteFunction
avvolge untf.Graph
. - Una
Function
gestisce una cache diConcreteFunction
e seleziona quella giusta per i tuoi input. -
tf.function
wrapping di una funzione Python, restituendo un oggettoFunction
. - Tracing crea un
tf.Graph
e lo avvolge in unaConcreteFunction
, nota anche come traccia.
Regole di tracciamento
Una Function
determina se riutilizzare una ConcreteFunction
tracciata calcolando una chiave cache da args e kwargs di un input. Una chiave cache è una chiave che identifica una ConcreteFunction
in base agli input args e kwargs della Function
chiamata, secondo le seguenti regole (che possono cambiare):
- La chiave generata per un
tf.Tensor
è la sua forma e dtype. - La chiave generata per un
tf.Variable
è un ID variabile univoco. - La chiave generata per una primitiva Python (come
int
,float
,str
) è il suo valore. - La chiave generata per nidificati
dict
s,list
s,tuple
s,namedtuple
s eattr
s è la tupla appiattita di leaf-keys (vedinest.flatten
). (Come risultato di questo appiattimento, la chiamata di una funzione concreta con una struttura di nidificazione diversa da quella utilizzata durante la traccia risulterà in un TypeError). - Per tutti gli altri tipi di Python la chiave è univoca per l'oggetto. In questo modo una funzione o un metodo viene tracciato indipendentemente per ogni istanza con cui viene chiamato.
Controllo del ritracciamento
Il ritracciamento, ovvero quando la tua Function
crea più di una traccia, aiuta a garantire che TensorFlow generi grafici corretti per ogni set di input. Tuttavia, il tracciamento è un'operazione costosa! Se la tua Function
ritraccia un nuovo grafico per ogni chiamata, scoprirai che il tuo codice viene eseguito più lentamente che se non avessi usato tf.function
.
Per controllare il comportamento di traccia, puoi utilizzare le seguenti tecniche:
- Specifica
input_signature
intf.function
per limitare la traccia.
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
print("Tracing with", x)
return tf.where(x % 2 == 0, x // 2, 3 * x + 1)
print(next_collatz(tf.constant([1, 2])))
# You specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([[1, 2], [3, 4]]))
# You specified an int32 dtype in the input signature, so this should fail.
with assert_raises(ValueError):
next_collatz(tf.constant([1.0, 2.0]))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([4 1], shape=(2,), dtype=int32) Caught expected exception <class 'ValueError'>: Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/1851403433.py", line 9, in <module> next_collatz(tf.constant([[1, 2], [3, 4]])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor( [[1 2] [3 4]], shape=(2, 2), dtype=int32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)). Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/1851403433.py", line 13, in <module> next_collatz(tf.constant([1.0, 2.0])) ValueError: Python inputs incompatible with input_signature: inputs: ( tf.Tensor([1. 2.], shape=(2,), dtype=float32)) input_signature: ( TensorSpec(shape=(None,), dtype=tf.int32, name=None)).
Specificare una dimensione [Nessuno] in
tf.TensorSpec
per consentire flessibilità nel riutilizzo delle tracce.Poiché TensorFlow abbina i tensori in base alla loro forma, l'utilizzo di una dimensione
None
come carattere jolly consentirà alleFunction
di riutilizzare le tracce per input di dimensioni variabili. L'input di dimensioni variabili può verificarsi se si dispone di sequenze di lunghezza diversa o immagini di dimensioni diverse per ciascun batch (consultare ad esempio i tutorial Transformer e Deep Dream ).
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def g(x):
print('Tracing with', x)
return x
# No retrace!
print(g(tf.constant([1, 2, 3])))
print(g(tf.constant([1, 2, 3, 4, 5])))
Tracing with Tensor("x:0", shape=(None,), dtype=int32) tf.Tensor([1 2 3], shape=(3,), dtype=int32) tf.Tensor([1 2 3 4 5], shape=(5,), dtype=int32)
Trasmetti argomenti Python a Tensors per ridurre il ritracciamento.
Spesso, gli argomenti Python vengono utilizzati per controllare gli iperparametri e le costruzioni di grafi, ad esempio
num_layers=10
otraining=True
ononlinearity='relu'
. Quindi, se l'argomento Python cambia, ha senso che dovresti ripercorrere il grafico.Tuttavia, è possibile che un argomento Python non venga utilizzato per controllare la costruzione del grafico. In questi casi, una modifica del valore Python può attivare un inutile ritracciamento. Prendi, ad esempio, questo ciclo di addestramento, che AutoGraph srotolerà dinamicamente. Nonostante le tracce multiple, il grafico generato è in realtà identico, quindi il ritracciamento non è necessario.
def train_one_step():
pass
@tf.function
def train(num_steps):
print("Tracing with num_steps = ", num_steps)
tf.print("Executing with num_steps = ", num_steps)
for _ in tf.range(num_steps):
train_one_step()
print("Retracing occurs for different Python arguments.")
train(num_steps=10)
train(num_steps=20)
print()
print("Traces are reused for Tensor arguments.")
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))
Retracing occurs for different Python arguments. Tracing with num_steps = 10 Executing with num_steps = 10 Tracing with num_steps = 20 Executing with num_steps = 20 Traces are reused for Tensor arguments. Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32) Executing with num_steps = 10 Executing with num_steps = 20
Se è necessario forzare il ritracciamento, creare una nuova Function
. È garantito che gli oggetti Function
separati non condividano tracce.
def f():
print('Tracing!')
tf.print('Executing')
tf.function(f)()
tf.function(f)()
Tracing! Executing Tracing! Executing
Ottenere funzioni concrete
Ogni volta che viene tracciata una funzione, viene creata una nuova funzione concreta. Puoi ottenere direttamente una funzione concreta, usando get_concrete_function
.
print("Obtaining concrete trace")
double_strings = double.get_concrete_function(tf.constant("a"))
print("Executing traced function")
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
Obtaining concrete trace Executing traced function tf.Tensor(b'aa', shape=(), dtype=string) tf.Tensor(b'bb', shape=(), dtype=string)
# You can also call get_concrete_function on an InputSpec
double_strings_from_inputspec = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.string))
print(double_strings_from_inputspec(tf.constant("c")))
tf.Tensor(b'cc', shape=(), dtype=string)
La stampa di una ConcreteFunction
mostra un riepilogo dei suoi argomenti di input (con tipi) e del suo tipo di output.
print(double_strings)
ConcreteFunction double(a) Args: a: string Tensor, shape=() Returns: string Tensor, shape=()
Puoi anche recuperare direttamente la firma di una funzione concreta.
print(double_strings.structured_input_signature)
print(double_strings.structured_outputs)
((TensorSpec(shape=(), dtype=tf.string, name='a'),), {}) Tensor("Identity:0", shape=(), dtype=string)
L'utilizzo di una traccia concreta con tipi incompatibili genererà un errore
with assert_raises(tf.errors.InvalidArgumentError):
double_strings(tf.constant(1))
Caught expected exception <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3196284684.py", line 2, in <module> double_strings(tf.constant(1)) tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_162 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_162]
Potresti notare che agli argomenti Python viene riservato un trattamento speciale nella firma di input di una funzione concreta. Prima di TensorFlow 2.3, gli argomenti Python venivano semplicemente rimossi dalla firma della funzione concreta. A partire da TensorFlow 2.3, gli argomenti Python rimangono nella firma, ma sono vincolati a prendere il valore impostato durante la traccia.
@tf.function
def pow(a, b):
return a ** b
square = pow.get_concrete_function(a=tf.TensorSpec(None, tf.float32), b=2)
print(square)
ConcreteFunction pow(a, b=2) Args: a: float32 Tensor, shape=<unknown> Returns: float32 Tensor, shape=<unknown>
assert square(tf.constant(10.0)) == 100
with assert_raises(TypeError):
square(tf.constant(10.0), b=3)
Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1721, in _call_impl cancellation_manager) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1765, in _call_with_flat_signature raise TypeError(f"{self._flat_signature_summary()} got unexpected " TypeError: pow(a) got unexpected keyword arguments: b. During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/2310937119.py", line 4, in <module> square(tf.constant(10.0), b=3) TypeError: ConcreteFunction pow(a, b) was constructed with int value 2 in b, but was called with int value 3.
Ottenere grafici
Ogni funzione concreta è un wrapper richiamabile attorno a un tf.Graph
. Sebbene il recupero dell'oggetto tf.Graph
effettivo non sia qualcosa che normalmente devi fare, puoi ottenerlo facilmente da qualsiasi funzione concreta.
graph = double_strings.graph
for node in graph.as_graph_def().node:
print(f'{node.input} -> {node.name}')
[] -> a ['a', 'a'] -> add ['add'] -> Identity
Debug
In generale, il debug del codice è più semplice in modalità desiderosa che all'interno di tf.function
. Dovresti assicurarti che il tuo codice venga eseguito senza errori in modalità desiderosa prima di decorare con tf.function
. Per assistere nel processo di debug, puoi chiamare tf.config.run_functions_eagerly(True)
per disabilitare e riattivare globalmente tf.function
.
Quando si rintracciano problemi che compaiono solo all'interno di tf.function
, ecco alcuni suggerimenti:
- Le normali vecchie chiamate di
print
Python vengono eseguite solo durante la traccia, aiutandoti a rintracciare quando la tua funzione viene (ri)tracciata. - Le chiamate
tf.print
verranno eseguite ogni volta e possono aiutarti a rintracciare i valori intermedi durante l'esecuzione. -
tf.debugging.enable_check_numerics
è un modo semplice per rintracciare dove vengono creati NaN e Inf. -
pdb
(il debugger di Python ) può aiutarti a capire cosa sta succedendo durante la traccia. (Attenzione:pdb
ti porterà nel codice sorgente trasformato in AutoGraph.)
Trasformazioni di autografi
AutoGraph è una libreria che è attiva per impostazione predefinita in tf.function
e trasforma un sottoinsieme di codice desideroso di Python in operazioni TensorFlow compatibili con i grafici. Ciò include il flusso di controllo come if
, for
, while
.
Le operazioni di TensorFlow come tf.cond
e tf.while_loop
continuano a funzionare, ma il flusso di controllo è spesso più facile da scrivere e capire quando è scritto in Python.
# A simple loop
@tf.function
def f(x):
while tf.reduce_sum(x) > 1:
tf.print(x)
x = tf.tanh(x)
return x
f(tf.random.uniform([5]))
[0.666458249 0.713946581 0.723879576 0.330758929 0.184087753] [0.582645297 0.613145649 0.619306684 0.319202513 0.182036072] [0.524585426 0.546337605 0.550645113 0.308785647 0.18005164] [0.481231302 0.497770309 0.501003504 0.299331933 0.178130865] [0.447229207 0.460361809 0.462906033 0.290701121 0.176270396] [0.419618756 0.430379033 0.432449728 0.282779962 0.174467146] [0.396609187 0.405638 0.407366514 0.275476 0.172718227] [0.377043903 0.384762734 0.386234313 0.268712848 0.17102097] [0.360137492 0.366836458 0.368109286 0.262426734 0.169372901] [0.345335096 0.351221472 0.352336824 0.256563932 0.167771652] [0.332231969 0.337458342 0.338446289 0.251078814 0.166215062] [0.320524871 0.325206399 0.326089561 0.24593246 0.164701089] [0.309981436 0.314206958 0.31500268 0.241091311 0.163227797] [0.300420195 0.304259449 0.304981351 0.236526251 0.161793426] [0.291697085 0.295205742 0.295864582 0.232211992 0.160396278] [0.283696055 0.286919087 0.287523568 0.228126258 0.159034774] [0.276322395 0.279296666 0.27985391 0.224249557 0.157707423] [0.269497961 0.272254 0.272769839 0.220564634 0.15641281] [0.263157606 0.265720904 0.266200244 0.21705614 0.155149609] [0.257246554 0.259638608 0.260085613 0.213710397 0.153916568] [0.251718313 0.25395745 0.254375577 0.210515186 0.152712509] [0.246533215 0.248635098 0.249027327 0.207459539 0.151536316] [0.241657034 0.243635193 0.244004101 0.204533577 0.15038693] [0.237060249 0.238926381 0.239274174 0.201728329 0.149263337] [0.232717097 0.234481394 0.234810054 0.199035719 0.148164615] [0.228605017 0.230276451 0.230587661 0.196448416 0.147089839] [0.224704206 0.226290658 0.22658591 0.193959698 0.14603813] [0.220997125 0.222505584 0.222786173 0.191563457 0.145008713] <tf.Tensor: shape=(5,), dtype=float32, numpy= array([0.21746822, 0.21890487, 0.21917202, 0.18925412, 0.14400077], dtype=float32)>
Se sei curioso puoi controllare il codice generato dall'autografo.
print(tf.autograph.to_code(f.python_function))
def tf__f(x): with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope: do_return = False retval_ = ag__.UndefinedReturnValue() def get_state(): return (x,) def set_state(vars_): nonlocal x (x,) = vars_ def loop_body(): nonlocal x ag__.converted_call(ag__.ld(tf).print, (ag__.ld(x),), None, fscope) x = ag__.converted_call(ag__.ld(tf).tanh, (ag__.ld(x),), None, fscope) def loop_test(): return (ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.ld(x),), None, fscope) > 1) ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {}) try: do_return = True retval_ = ag__.ld(x) except: do_return = False raise return fscope.ret(retval_, do_return)
Condizionali
AutoGraph converte alcune istruzioni if <condition>
nelle chiamate tf.cond
equivalenti. Questa sostituzione viene effettuata se <condition>
è un Tensor. In caso contrario, l'istruzione if
viene eseguita come condizionale Python.
Un condizionale Python viene eseguito durante il tracciamento, quindi verrà aggiunto esattamente un ramo del condizionale al grafico. Senza AutoGraph, questo grafico tracciato non sarebbe in grado di prendere il ramo alternativo se esiste un flusso di controllo dipendente dai dati.
tf.cond
traccia e aggiunge entrambi i rami del condizionale al grafico, selezionando dinamicamente un ramo al momento dell'esecuzione. Il tracciamento può avere effetti collaterali indesiderati; controlla gli effetti di traccia di AutoGraph per ulteriori informazioni.
@tf.function
def fizzbuzz(n):
for i in tf.range(1, n + 1):
print('Tracing for loop')
if i % 15 == 0:
print('Tracing fizzbuzz branch')
tf.print('fizzbuzz')
elif i % 3 == 0:
print('Tracing fizz branch')
tf.print('fizz')
elif i % 5 == 0:
print('Tracing buzz branch')
tf.print('buzz')
else:
print('Tracing default branch')
tf.print(i)
fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))
Tracing for loop Tracing fizzbuzz branch Tracing fizz branch Tracing buzz branch Tracing default branch 1 2 fizz 4 buzz 1 2 fizz 4 buzz fizz 7 8 fizz buzz 11 fizz 13 14 fizzbuzz 16 17 fizz 19 buzz
Consultare la documentazione di riferimento per ulteriori restrizioni sulle istruzioni if convertite in AutoGraph.
Cicli
AutoGraph converte alcune istruzioni for
e while
nelle operazioni di loop TensorFlow equivalenti, come tf.while_loop
. Se non viene convertito, il ciclo for
o while
viene eseguito come ciclo Python.
Tale sostituzione avviene nelle seguenti situazioni:
-
for x in y
: sey
è un Tensor, converti intf.while_loop
. Nel caso speciale in cuiy
è untf.data.Dataset
, viene generata una combinazione ditf.data.Dataset
ops. -
while <condition>
: se<condition>
è un tensore, converti intf.while_loop
.
Un ciclo Python viene eseguito durante la traccia, aggiungendo operazioni aggiuntive a tf.Graph
per ogni iterazione del ciclo.
Un ciclo TensorFlow traccia il corpo del ciclo e seleziona dinamicamente quante iterazioni eseguire al momento dell'esecuzione. Il corpo del ciclo appare solo una volta nel tf.Graph
generato.
Consultare la documentazione di riferimento per ulteriori restrizioni sulle istruzioni for
e while
convertite in AutoGraph.
Ciclo sui dati Python
Una trappola comune è quella di scorrere i dati Python/NumPy all'interno di un tf.function
. Questo ciclo verrà eseguito durante il processo di tracciamento, aggiungendo una copia del tuo modello al tf.Graph
per ogni iterazione del ciclo.
Se si desidera eseguire il wrapping dell'intero ciclo di addestramento in tf.function
, il modo più sicuro per eseguire questa operazione è racchiudere i dati come tf.data.Dataset
in modo che AutoGraph srotoli dinamicamente il ciclo di addestramento.
def measure_graph_size(f, *args):
g = f.get_concrete_function(*args).graph
print("{}({}) contains {} nodes in its graph".format(
f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))
@tf.function
def train(dataset):
loss = tf.constant(0)
for x, y in dataset:
loss += tf.abs(y - x) # Some dummy computation.
return loss
small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
lambda: big_data, (tf.int32, tf.int32)))
train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 6 nodes in its graph
Quando si avvolgono i dati Python/NumPy in un set di dati, prestare attenzione a tf.data.Dataset.from_generator
rispetto tf.data.Dataset.from_tensors
. Il primo manterrà i dati in Python e li recupererà tramite tf.py_function
che può avere implicazioni sulle prestazioni, mentre il secondo raggruppa una copia dei dati come un grande nodo tf.constant()
nel grafico, che può avere implicazioni sulla memoria.
Leggere i dati dai file tramite TFRecordDataset
, CsvDataset
, ecc. è il modo più efficace per consumare dati, poiché TensorFlow stesso può gestire il caricamento asincrono e il prelettura dei dati, senza dover coinvolgere Python. Per ulteriori informazioni, vedere la guida alle pipeline di input tf.data
: Build TensorFlow .
Accumulo di valori in un ciclo
Un modello comune consiste nell'accumulare valori intermedi da un ciclo. Normalmente, ciò si ottiene aggiungendo a un elenco Python o aggiungendo voci a un dizionario Python. Tuttavia, poiché si tratta di effetti collaterali di Python, non funzioneranno come previsto in un ciclo svolto dinamicamente. Utilizzare tf.TensorArray
per accumulare risultati da un ciclo svolto dinamicamente.
batch_size = 2
seq_len = 3
feature_size = 4
def rnn_step(inp, state):
return inp + state
@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
max_seq_len = input_data.shape[0]
states = tf.TensorArray(tf.float32, size=max_seq_len)
state = initial_state
for i in tf.range(max_seq_len):
state = rnn_step(input_data[i], state)
states = states.write(i, state)
return tf.transpose(states.stack(), [1, 0, 2])
dynamic_rnn(rnn_step,
tf.random.uniform([batch_size, seq_len, feature_size]),
tf.zeros([batch_size, feature_size]))
<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy= array([[[0.06309307, 0.9938811 , 0.90789986, 0.42136216], [0.44997275, 1.9107027 , 1.0716251 , 0.717237 ], [0.6026064 , 2.1622117 , 1.4164022 , 1.4153863 ]], [[0.04946005, 0.69127274, 0.56848884, 0.22406638], [0.8148316 , 1.0278493 , 0.6207781 , 1.1935129 ], [0.9178308 , 1.320889 , 0.989761 , 2.0120025 ]]], dtype=float32)>
Limitazioni
Function
TensorFlow presenta alcune limitazioni di progettazione di cui dovresti essere a conoscenza durante la conversione di una funzione Python in una Function
.
Esecuzione degli effetti collaterali di Python
Gli effetti collaterali, come la stampa, l'aggiunta a elenchi e la mutazione di globali, possono comportarsi in modo imprevisto all'interno di una Function
, a volte eseguendo due volte o non tutti. Si verificano solo la prima volta che si chiama una Function
con un insieme di input. Successivamente, il tracciato tf.Graph
viene rieseguito, senza eseguire il codice Python.
La regola generale è evitare di fare affidamento sugli effetti collaterali di Python nella logica e usarli solo per eseguire il debug delle tracce. In caso contrario, le API TensorFlow come tf.data
, tf.print
, tf.summary
, tf.Variable.assign
e tf.TensorArray
sono il modo migliore per garantire che il codice venga eseguito dal runtime TensorFlow ad ogni chiamata.
@tf.function
def f(x):
print("Traced with", x)
tf.print("Executed with", x)
f(1)
f(1)
f(2)
Traced with 1 Executed with 1 Executed with 1 Traced with 2 Executed with 2
Se desideri eseguire il codice Python durante ogni invocazione di una Function
, tf.py_function
è un portello di uscita. Lo svantaggio di tf.py_function
è che non è portatile o particolarmente performante, non può essere salvato con SavedModel e non funziona bene in configurazioni distribuite (multi-GPU, TPU). Inoltre, poiché tf.py_function
deve essere cablato nel grafico, trasmette tutti gli input/output ai tensori.
Modifica delle variabili globali e libere di Python
La modifica delle variabili globali e libere di Python conta come un effetto collaterale di Python, quindi avviene solo durante il tracciamento.
external_list = []
@tf.function
def side_effect(x):
print('Python side effect')
external_list.append(x)
side_effect(1)
side_effect(1)
side_effect(1)
# The list append only happened once!
assert len(external_list) == 1
Python side effect
A volte i comportamenti imprevisti sono molto difficili da notare. Nell'esempio seguente, il counter
lo scopo di salvaguardare l'incremento di una variabile. Tuttavia, poiché è un intero Python e non un oggetto TensorFlow, il suo valore viene acquisito durante la prima traccia. Quando viene utilizzata la assign_add
tf.function
registrato incondizionatamente nel grafico sottostante. Pertanto v
aumenterà di 1, ogni volta che viene chiamata la tf.function
. Questo problema è comune tra gli utenti che tentano di migrare il codice Tensorflow in modalità Grpah su Tensorflow 2 utilizzando i decoratori tf.function
, quando gli effetti collaterali Python (il counter
nell'esempio) vengono utilizzati per determinare quali operazioni eseguire ( assign_add
nell'esempio ). Di solito, gli utenti se ne rendono conto solo dopo aver visto risultati numerici sospetti, o prestazioni significativamente inferiori alle attese (ad esempio se l'operazione protetta è molto costosa).
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# A python side-effect
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 2, 3
1 2 3
Una soluzione alternativa per ottenere il comportamento previsto consiste nell'usare tf.init_scope
per sollevare le operazioni al di fuori del grafico della funzione. Ciò garantisce che l'incremento della variabile venga eseguito solo una volta durante il tempo di tracciamento. Va notato init_scope
ha altri effetti collaterali tra cui il flusso di controllo cancellato e il nastro gradiente. A volte l'uso di init_scope
può diventare troppo complesso per essere gestito in modo realistico.
class Model(tf.Module):
def __init__(self):
self.v = tf.Variable(0)
self.counter = 0
@tf.function
def __call__(self):
if self.counter == 0:
# Lifts ops out of function-building graphs
with tf.init_scope():
self.counter += 1
self.v.assign_add(1)
return self.v
m = Model()
for n in range(3):
print(m().numpy()) # prints 1, 1, 1
1 1 1
In sintesi, come regola pratica, dovresti evitare di mutare oggetti Python come interi o contenitori come elenchi che risiedono al di fuori di Function
. Invece, usa argomenti e oggetti TF. Ad esempio, la sezione "Accumulo di valori in un ciclo" contiene un esempio di come è possibile implementare operazioni di tipo elenco.
È possibile, in alcuni casi, acquisire e manipolare lo stato se è un tf.Variable
. In questo modo i pesi dei modelli Keras vengono aggiornati con ripetute chiamate alla stessa ConcreteFunction
.
Utilizzo di iteratori e generatori Python
Molte funzionalità di Python, come generatori e iteratori, si basano sul runtime Python per tenere traccia dello stato. In generale, sebbene questi costrutti funzionino come previsto in modalità ansiosa, sono esempi di effetti collaterali di Python e quindi si verificano solo durante il tracciamento.
@tf.function
def buggy_consume_next(iterator):
tf.print("Value:", next(iterator))
iterator = iter([1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)
Value: 1 Value: 1 Value: 1
Proprio come TensorFlow ha un tf.TensorArray
specializzato per i costrutti di elenco, ha un tf.data.Iterator
specializzato per i costrutti di iterazione. Vedere la sezione sulle trasformazioni di AutoGraph per una panoramica. Inoltre, l'API tf.data
può aiutare a implementare modelli di generatore:
@tf.function
def good_consume_next(iterator):
# This is ok, iterator is a tf.data.Iterator
tf.print("Value:", next(iterator))
ds = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(ds)
good_consume_next(iterator)
good_consume_next(iterator)
good_consume_next(iterator)
Value: 1 Value: 2 Value: 3
Tutti gli output di una tf.function devono essere valori di ritorno
Ad eccezione di tf.Variable
s, una tf.function deve restituire tutti i suoi output. Il tentativo di accedere direttamente a qualsiasi tensore da una funzione senza passare attraverso i valori di ritorno provoca "perdite".
Ad esempio, la funzione seguente "perde" il tensore a
tramite Python global x
:
x = None
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return a + 2
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
3 'Tensor' object has no attribute 'numpy'
Questo è vero anche se viene restituito anche il valore trapelato:
@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return x # Good - uses local tensor
correct_a = leaky_function(tf.constant(1))
print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)
@tf.function
def captures_leaked_tensor(b):
b += x # Bad - `x` is leaked from `leaky_function`
return b
with assert_raises(TypeError):
captures_leaked_tensor(tf.constant(2))
2 'Tensor' object has no attribute 'numpy' Caught expected exception <class 'TypeError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/566849597.py", line 21, in <module> captures_leaked_tensor(tf.constant(2)) TypeError: Originated from a graph execution error. The graph execution error is detected at a node built at (most recent call last): >>> File /usr/lib/python3.7/runpy.py, line 193, in _run_module_as_main >>> File /usr/lib/python3.7/runpy.py, line 85, in _run_code >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel_launcher.py, line 16, in <module> >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/traitlets/config/application.py, line 846, in launch_instance >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelapp.py, line 677, in start >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tornado/platform/asyncio.py, line 199, in start >>> File /usr/lib/python3.7/asyncio/base_events.py, line 534, in run_forever >>> File /usr/lib/python3.7/asyncio/base_events.py, line 1771, in _run_once >>> File /usr/lib/python3.7/asyncio/events.py, line 88, in _run >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 457, in dispatch_queue >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 446, in process_one >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 353, in dispatch_shell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/kernelbase.py, line 648, in execute_request >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/ipkernel.py, line 353, in do_execute >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/ipykernel/zmqshell.py, line 533, in run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2902, in run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 2947, in _run_cell >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/async_helpers.py, line 68, in _pseudo_sync_runner >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3173, in run_cell_async >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3364, in run_ast_nodes >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/IPython/core/interactiveshell.py, line 3444, in run_code >>> File /tmp/ipykernel_26244/566849597.py, line 7, in <module> >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 910, in __call__ >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 958, in _call >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 781, in _initialize >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3157, in _get_concrete_function_internal_garbage_collected >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3557, in _maybe_define_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/function.py, line 3402, in _create_graph_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1143, in func_graph_from_py_func >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py, line 672, in wrapped_fn >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 1125, in autograph_handler >>> File /tmp/ipykernel_26244/566849597.py, line 4, in leaky_function >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1383, in binary_op_wrapper >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/traceback_utils.py, line 150, in error_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py, line 1096, in op_dispatch_handler >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py, line 1737, in _add_dispatch >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py, line 476, in add_v2 >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py, line 746, in _apply_op_helper >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py, line 691, in _create_op_internal >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 3705, in _create_op_internal >>> File /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/ops.py, line 2101, in __init__ Error detected in node 'add' defined at: File "/tmp/ipykernel_26244/566849597.py", line 4, in leaky_function TypeError: tf.Graph captured an external symbolic tensor. The symbolic tensor 'add:0' created by node 'add' is captured by the tf.Graph being executed as an input. But a tf.Graph is not allowed to take symbolic tensors from another graph as its inputs. Make sure all captured inputs of the executing tf.Graph are not symbolic tensors. Use return values, explicit Python locals or TensorFlow collections to access it. Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values for more information.
Di solito, perdite come queste si verificano quando si utilizzano istruzioni o strutture di dati Python. Oltre a far trapelare tensori inaccessibili, è probabile che tali istruzioni siano anche sbagliate perché contano come effetti collaterali di Python e non è garantito che vengano eseguite ad ogni chiamata di funzione.
I modi comuni per perdere i tensori locali includono anche la mutazione di una raccolta Python esterna o di un oggetto:
class MyClass:
def __init__(self):
self.field = None
external_list = []
external_object = MyClass()
def leaky_function():
a = tf.constant(1)
external_list.append(a) # Bad - leaks tensor
external_object.field = a # Bad - leaks tensor
Le funzioni ricorsive tf. non sono supportate
Le Function
ricorsive non sono supportate e potrebbero causare loop infiniti. Per esempio,
@tf.function
def recursive_fn(n):
if n > 0:
return recursive_fn(n - 1)
else:
return 1
with assert_raises(Exception):
recursive_fn(tf.constant(5)) # Bad - maximum recursion error.
Caught expected exception <class 'Exception'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/2233998312.py", line 9, in <module> recursive_fn(tf.constant(5)) # Bad - maximum recursion error. tensorflow.python.autograph.impl.api.StagingError: in user code: File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 4, in recursive_fn * return recursive_fn(n - 1) File "/tmp/ipykernel_26244/2233998312.py", line 3, in recursive_fn * if n > 0: File "/usr/lib/python3.7/abc.py", line 139, in __instancecheck__ return _abc_instancecheck(cls, instance) RecursionError: maximum recursion depth exceeded while calling a Python object
Anche se una Function
ricorsiva sembra funzionare, la funzione python verrà tracciata più volte e potrebbe avere implicazioni sulle prestazioni. Per esempio,
@tf.function
def recursive_fn(n):
if n > 0:
print('tracing')
return recursive_fn(n - 1)
else:
return 1
recursive_fn(5) # Warning - multiple tracings
tracing tracing tracing tracing tracing <tf.Tensor: shape=(), dtype=int32, numpy=1>
Problemi noti
Se la Function
non viene valutata correttamente, l'errore potrebbe essere spiegato da questi problemi noti che dovrebbero essere risolti in futuro.
A seconda delle variabili globali e libere di Python
Function
crea una nuova ConcreteFunction
quando viene chiamata con un nuovo valore di un argomento Python. Tuttavia, non lo fa per la chiusura Python, i globali o i non locali di quella Function
. Se il loro valore cambia tra le chiamate a Function
, la Function
utilizzerà comunque i valori che avevano quando è stata tracciata. Questo è diverso da come funzionano le normali funzioni Python.
Per questo motivo, dovresti seguire uno stile di programmazione funzionale che utilizza argomenti invece di chiudersi su nomi esterni.
@tf.function
def buggy_add():
return 1 + foo
@tf.function
def recommended_add(foo):
return 1 + foo
foo = 1
print("Buggy:", buggy_add())
print("Correct:", recommended_add(foo))
Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo = 100
print("Buggy:", buggy_add()) # Did not change!
print("Correct:", recommended_add(foo))
Updating the value of `foo` to 100! Buggy: tf.Tensor(2, shape=(), dtype=int32) Correct: tf.Tensor(101, shape=(), dtype=int32)
Un altro modo per aggiornare un valore globale consiste nel renderlo tf.Variable
e utilizzare invece il metodo Variable.assign
.
@tf.function
def variable_add():
return 1 + foo
foo = tf.Variable(1)
print("Variable:", variable_add())
Variable: tf.Tensor(2, shape=(), dtype=int32)
print("Updating the value of `foo` to 100!")
foo.assign(100)
print("Variable:", variable_add())
Updating the value of `foo` to 100! Variable: tf.Tensor(101, shape=(), dtype=int32)
A seconda degli oggetti Python
La raccomandazione di passare oggetti Python come argomenti in tf.function
ha una serie di problemi noti, che dovrebbero essere risolti in futuro. In generale, puoi fare affidamento su una traccia coerente se usi una primitiva Python o una struttura tf.nest
come argomento o passi un'istanza diversa di un oggetto in un Function
. Tuttavia, Function
non creerà una nuova traccia quando si passa lo stesso oggetto e si modificano solo i suoi attributi .
class SimpleModel(tf.Module):
def __init__(self):
# These values are *not* tf.Variables.
self.bias = 0.
self.weight = 2.
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
simple_model = SimpleModel()
x = tf.constant(10.)
print(evaluate(simple_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
simple_model.bias += 5.0
print(evaluate(simple_model, x)) # Didn't change :(
Adding bias! tf.Tensor(20.0, shape=(), dtype=float32)
L'utilizzo della stessa Function
per valutare l'istanza aggiornata del modello sarà difettoso poiché il modello aggiornato ha la stessa chiave di cache del modello originale.
Per questo motivo, ti consigliamo di scrivere la tua Function
per evitare di dipendere da attributi di oggetti mutabili o di creare nuovi oggetti.
Se ciò non è possibile, una soluzione alternativa consiste nel creare nuove Function
ogni volta che modifichi l'oggetto per forzare il ritracciamento:
def evaluate(model, x):
return model.weight * x + model.bias
new_model = SimpleModel()
evaluate_no_bias = tf.function(evaluate).get_concrete_function(new_model, x)
# Don't pass in `new_model`, `Function` already captured its state during tracing.
print(evaluate_no_bias(x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
new_model.bias += 5.0
# Create new Function and ConcreteFunction since you modified new_model.
evaluate_with_bias = tf.function(evaluate).get_concrete_function(new_model, x)
print(evaluate_with_bias(x)) # Don't pass in `new_model`.
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
Poiché il ritracciamento può essere costoso , puoi usare tf.Variable
s come attributi dell'oggetto, che possono essere mutati (ma non modificati, attenzione!) per un effetto simile senza bisogno di un ritracciamento.
class BetterModel:
def __init__(self):
self.bias = tf.Variable(0.)
self.weight = tf.Variable(2.)
@tf.function
def evaluate(model, x):
return model.weight * x + model.bias
better_model = BetterModel()
print(evaluate(better_model, x))
tf.Tensor(20.0, shape=(), dtype=float32)
print("Adding bias!")
better_model.bias.assign_add(5.0) # Note: instead of better_model.bias += 5
print(evaluate(better_model, x)) # This works!
Adding bias! tf.Tensor(25.0, shape=(), dtype=float32)
Creazione di tf.Variabili
Function
supporta solo tf.Variable
singleton creati una volta alla prima chiamata e riutilizzati nelle successive chiamate di funzione. Il frammento di codice seguente creerebbe un nuovo tf.Variable
in ogni chiamata di funzione, che si traduce in un'eccezione ValueError
.
Esempio:
@tf.function
def f(x):
v = tf.Variable(1.0)
return v
with assert_raises(ValueError):
f(1.0)
Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3018268426.py", line 7, in <module> f(1.0) ValueError: in user code: File "/tmp/ipykernel_26244/3018268426.py", line 3, in f * v = tf.Variable(1.0) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
Un modello comune utilizzato per aggirare questa limitazione consiste nell'iniziare con un valore Python None, quindi creare in modo condizionale tf.Variable
se il valore è None:
class Count(tf.Module):
def __init__(self):
self.count = None
@tf.function
def __call__(self):
if self.count is None:
self.count = tf.Variable(0)
return self.count.assign_add(1)
c = Count()
print(c())
print(c())
tf.Tensor(1, shape=(), dtype=int32) tf.Tensor(2, shape=(), dtype=int32)
Utilizzo con più ottimizzatori Keras
Potresti riscontrare ValueError: tf.function only supports singleton tf.Variables created on the first call.
quando si utilizza più di un ottimizzatore Keras con una tf.function
. Questo errore si verifica perché gli ottimizzatori creano internamente tf.Variables
quando applicano i gradienti per la prima volta.
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
@tf.function
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
train_step(w, x, y, opt1)
print("Calling `train_step` with different optimizer...")
with assert_raises(ValueError):
train_step(w, x, y, opt2)
Calling `train_step` with different optimizer... Caught expected exception <class 'ValueError'>: Traceback (most recent call last): File "/tmp/ipykernel_26244/3551158538.py", line 8, in assert_raises yield File "/tmp/ipykernel_26244/3167358578.py", line 18, in <module> train_step(w, x, y, opt2) ValueError: in user code: File "/tmp/ipykernel_26244/3167358578.py", line 9, in train_step * optimizer.apply_gradients(zip(gradients, [w])) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 639, in apply_gradients ** self._create_all_weights(var_list) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 828, in _create_all_weights _ = self.iterations File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 835, in __getattribute__ return super(OptimizerV2, self).__getattribute__(name) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 995, in iterations aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/optimizer_v2.py", line 1202, in add_weight aggregation=aggregation) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_utils.py", line 129, in make_variable shape=variable_shape if variable_shape else None) ValueError: tf.function only supports singleton tf.Variables created on the first call. Make sure the tf.Variable is only created once or created outside tf.function. See https://www.tensorflow.org/guide/function#creating_tfvariables for more information.
Se è necessario modificare l'ottimizzatore durante l'addestramento, una soluzione alternativa consiste nel creare una nuova Function
per ogni ottimizzatore, chiamando direttamente ConcreteFunction
.
opt1 = tf.keras.optimizers.Adam(learning_rate = 1e-2)
opt2 = tf.keras.optimizers.Adam(learning_rate = 1e-3)
# Not a tf.function.
def train_step(w, x, y, optimizer):
with tf.GradientTape() as tape:
L = tf.reduce_sum(tf.square(w*x - y))
gradients = tape.gradient(L, [w])
optimizer.apply_gradients(zip(gradients, [w]))
w = tf.Variable(2.)
x = tf.constant([-1.])
y = tf.constant([2.])
# Make a new Function and ConcreteFunction for each optimizer.
train_step_1 = tf.function(train_step).get_concrete_function(w, x, y, opt1)
train_step_2 = tf.function(train_step).get_concrete_function(w, x, y, opt2)
for i in range(10):
if i % 2 == 0:
train_step_1(w, x, y) # `opt1` is not used as a parameter.
else:
train_step_2(w, x, y) # `opt2` is not used as a parameter.
Utilizzo con più modelli Keras
Potresti anche riscontrare ValueError: tf.function only supports singleton tf.Variables created on the first call.
quando si passano diverse istanze del modello alla stessa Function
.
Questo errore si verifica perché i modelli Keras (che non hanno la loro forma di input definita ) e i livelli Keras creano tf.Variables
quando vengono chiamati per la prima volta. Potresti tentare di inizializzare quelle variabili all'interno di un Function
, che è già stato chiamato. Per evitare questo errore, prova a chiamare model.build(input_shape)
per inizializzare tutti i pesi prima di addestrare il modello.
Ulteriori letture
Per informazioni su come esportare e caricare una Function
, vedere la guida SavedModel . Per ulteriori informazioni sulle ottimizzazioni dei grafici eseguite dopo il tracciamento, consulta la guida Grappler . Per informazioni su come ottimizzare la pipeline di dati e profilare il modello, consulta la guida di Profiler .