Zobacz na TensorFlow.org | Uruchom w Google Colab | Zobacz na GitHub | Pobierz notatnik |
Przegląd
Ten przewodnik zawiera listę najlepszych praktyk dotyczących pisania kodu przy użyciu TensorFlow 2 (TF2), jest napisany dla użytkowników, którzy niedawno przeszli z TensorFlow 1 (TF1). Zapoznaj się z sekcją dotyczącą migracji w przewodniku, aby uzyskać więcej informacji na temat migracji kodu TF1 do TF2.
Ustawiać
Zaimportuj TensorFlow i inne zależności dla przykładów w tym przewodniku.
import tensorflow as tf
import tensorflow_datasets as tfds
Zalecenia dotyczące idiomatycznego TensorFlow 2
Refaktoryzuj swój kod na mniejsze moduły
Dobrą praktyką jest refaktoryzacja kodu na mniejsze funkcje, które są wywoływane w razie potrzeby. Aby uzyskać najlepszą wydajność, powinieneś spróbować udekorować największe bloki obliczeń, które możesz w tf.function
(zauważ, że zagnieżdżone funkcje Pythona wywoływane przez tf.function
nie wymagają oddzielnych dekoracji, chyba że chcesz użyć innego jit_compile
ustawienia funkcji tf.function
. W zależności od przypadku użycia może to być wiele kroków treningowych lub nawet cała pętla treningowa. W przypadku użycia wnioskowania może to być pojedynczy przebieg modelu.
Dostosuj domyślną szybkość uczenia się dla niektórych tf.keras.optimizer
s
Niektóre optymalizatory Keras mają różne szybkości uczenia się w TF2. Jeśli zauważysz zmianę w zachowaniu zbieżności modeli, sprawdź domyślne współczynniki uczenia się.
Nie wprowadzono żadnych zmian w optimizers.RMSprop
, optimizers.SGD
optimizers.Adam
.
Zmieniły się następujące domyślne współczynniki uczenia się:
-
optimizers.Adagrad
od0.01
do0.001
-
optimizers.Adadelta
od1.0
do0.001
-
optimizers.Adamax
od0.002
do0.001
-
optimizers.Nadam
od0.002
do0.001
Użyj tf.Module
s i Keras do zarządzania zmiennymi
tf.Module
s i tf.keras.layers.Layer
s oferują wygodne variables
i właściwości trainable_variables
, które rekurencyjnie gromadzą wszystkie zmienne zależne. Ułatwia to lokalne zarządzanie zmiennymi tam, gdzie są używane.
Warstwy/modele Keras dziedziczą po tf.train.Checkpointable
i są zintegrowane z @tf.function
, co umożliwia bezpośrednie sprawdzanie punktu kontrolnego lub eksportowanie SavedModels z obiektów Keras. Nie musisz koniecznie korzystać z API Keras Model.fit
, aby skorzystać z tych integracji.
Przeczytaj sekcję o transferze uczenia się i dostrajaniu w przewodniku Keras, aby dowiedzieć się, jak zebrać podzbiór odpowiednich zmiennych za pomocą Keras.
Połącz tf.data.Dataset
s i tf.function
Pakiet TensorFlow Datasets ( tfds
) zawiera narzędzia do ładowania predefiniowanych zestawów danych jako obiektów tf.data.Dataset
. W tym przykładzie możesz załadować zestaw danych MNIST za pomocą tfds
:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Następnie przygotuj dane do treningu:
- Skaluj ponownie każdy obraz.
- Potasuj kolejność przykładów.
- Zbieraj partie obrazów i etykiet.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
Aby przykład był krótki, przytnij zbiór danych, aby zwracał tylko 5 partii:
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-12-08 17:15:01.637157: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Użyj zwykłej iteracji Pythona do iteracji danych uczących, które mieszczą się w pamięci. W przeciwnym razie tf.data.Dataset
to najlepszy sposób na przesyłanie strumieniowe danych szkoleniowych z dysku. Zbiory danych to iterable (nie iteratory) i działają tak samo jak inne iterable Pythona w gorliwym wykonywaniu. Możesz w pełni wykorzystać asynchroniczne funkcje wstępnego pobierania/przesyłania strumieniowego zestawu danych, opakowując swój kod w tf.function
, który zastępuje iterację Pythona równoważnymi operacjami wykresu przy użyciu AutoGraph.
@tf.function
def train(model, dataset, optimizer):
for x, y in dataset:
with tf.GradientTape() as tape:
# training=True is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
prediction = model(x, training=True)
loss = loss_fn(prediction, y)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
Jeśli korzystasz z interfejsu API Keras Model.fit
, nie musisz się martwić o iterację zestawu danych.
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)
Użyj pętli treningowych Keras
Jeśli nie potrzebujesz niskopoziomowej kontroli nad procesem treningowym, zalecane jest użycie wbudowanych metod Keras fit
, evaluate
i predict
. Te metody zapewniają jednolity interfejs do uczenia modelu niezależnie od implementacji (sekwencyjnej, funkcjonalnej lub podklasy).
Zaletami tych metod są:
- Akceptują tablice Numpy, generatory Pythona i
tf.data.Datasets
. - Stosują regularyzację i straty aktywacyjne automatycznie.
- Obsługują
tf.distribute
, gdzie kod szkolenia pozostaje taki sam niezależnie od konfiguracji sprzętu . - Obsługują arbitralne nabytki jako straty i metryki.
- Obsługują one wywołania zwrotne, takie jak
tf.keras.callbacks.TensorBoard
i niestandardowe wywołania zwrotne. - Są wydajne, automatycznie przy użyciu wykresów TensorFlow.
Oto przykład uczenia modelu przy użyciu Dataset
. Aby dowiedzieć się, jak to działa, zapoznaj się z samouczkami .
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5 5/5 [==============================] - 9s 7ms/step - loss: 1.5762 - accuracy: 0.4938 Epoch 2/5 2021-12-08 17:15:11.145429: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 6ms/step - loss: 0.5087 - accuracy: 0.8969 Epoch 3/5 2021-12-08 17:15:11.559374: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 2s 5ms/step - loss: 0.3348 - accuracy: 0.9469 Epoch 4/5 2021-12-08 17:15:13.860407: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 5ms/step - loss: 0.2445 - accuracy: 0.9688 Epoch 5/5 2021-12-08 17:15:14.269850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 0s 6ms/step - loss: 0.2006 - accuracy: 0.9719 2021-12-08 17:15:14.717552: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. 5/5 [==============================] - 1s 4ms/step - loss: 1.4553 - accuracy: 0.5781 Loss 1.4552843570709229, Accuracy 0.578125 2021-12-08 17:15:15.862684: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Dostosuj trening i napisz własną pętlę
Jeśli modele Keras działają dla Ciebie, ale potrzebujesz większej elastyczności i kontroli nad krokiem treningowym lub zewnętrznymi pętlami treningowymi, możesz wdrożyć własne kroki treningowe lub nawet całe pętle treningowe. Zobacz przewodnik Keras dotyczący dostosowywania fit
, aby dowiedzieć się więcej.
Możesz także zaimplementować wiele rzeczy jako tf.keras.callbacks.Callback
.
Ta metoda ma wiele zalet wspomnianych wcześniej , ale daje kontrolę nad krokiem pociągu, a nawet zewnętrzną pętlą.
Standardowa pętla treningowa składa się z trzech kroków:
- Wykonaj iterację przez generator Pythona lub
tf.data.Dataset
, aby uzyskać partie przykładów. - Użyj
tf.GradientTape
do zbierania gradientów. - Użyj jednego z
tf.keras.optimizers
, aby zastosować aktualizacje wagi do zmiennych modelu.
Pamiętać:
- Zawsze dołączaj argument
training
w metodziecall
podklas warstw i modeli. - Upewnij się, że wywołałeś model z poprawnie ustawionym argumentem
training
. - W zależności od użycia zmienne modelu mogą nie istnieć, dopóki model nie zostanie uruchomiony na partii danych.
- Musisz ręcznie poradzić sobie z takimi rzeczami, jak straty związane z regularyzacją modelu.
Nie ma potrzeby uruchamiania inicjatorów zmiennych ani dodawania ręcznych zależności sterowania. tf.function
obsługuje automatyczne zależności sterowania i inicjalizację zmiennych podczas tworzenia.
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
for epoch in range(NUM_EPOCHS):
for inputs, labels in train_data:
train_step(inputs, labels)
print("Finished epoch", epoch)
2021-12-08 17:15:16.714849: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 0 2021-12-08 17:15:17.097043: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 1 2021-12-08 17:15:17.502480: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 2 2021-12-08 17:15:17.873701: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Finished epoch 3 Finished epoch 4 2021-12-08 17:15:18.344196: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Skorzystaj z tf.function
z przepływem sterowania w Pythonie
tf.function
umożliwia konwersję zależnego od danych przepływu sterowania na jego odpowiedniki w trybie wykresu, takie jak tf.cond
i tf.while_loop
.
Jednym z powszechnych miejsc, w których pojawia się przepływ sterowania zależny od danych, są modele sekwencyjne. tf.keras.layers.RNN
otacza komórkę RNN, umożliwiając statyczne lub dynamiczne rozwijanie cyklu. Na przykład możesz ponownie wdrożyć dynamiczne rozwijanie w następujący sposób.
class DynamicRNN(tf.keras.Model):
def __init__(self, rnn_cell):
super(DynamicRNN, self).__init__(self)
self.cell = rnn_cell
@tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
def call(self, input_data):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
timesteps = tf.shape(input_data)[0]
batch_size = tf.shape(input_data)[1]
outputs = tf.TensorArray(tf.float32, timesteps)
state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
for i in tf.range(timesteps):
output, state = self.cell(input_data[i], state)
outputs = outputs.write(i, output)
return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)
my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)
Przeczytaj przewodnik tf.function
, aby uzyskać więcej informacji.
Wskaźniki i straty w nowym stylu
Metryki i straty to zarówno obiekty, które chętnie pracują, jak i w tf.function
s.
Obiekt straty można wywołać i oczekuje ( y_true
, y_pred
) jako argumentów:
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815
Używaj metryk do zbierania i wyświetlania danych
Możesz użyć tf.metrics
do agregowania danych i tf.summary
do dzienników podsumowań i przekierować je do autora za pomocą menedżera kontekstu. Podsumowania są emitowane bezpośrednio do piszącego, co oznacza, że musisz podać wartość step
w miejscu wywołania.
summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
tf.summary.scalar('loss', 0.1, step=42)
Użyj tf.metrics
do agregowania danych przed zarejestrowaniem ich jako podsumowań. Metryki są stanowe; gromadzą wartości i zwracają skumulowany wynik po wywołaniu metody result
(takiej jak Mean.result
). Wyczyść skumulowane wartości za pomocą Model.reset_states
.
def train(model, optimizer, dataset, log_freq=10):
avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
for images, labels in dataset:
loss = train_step(model, optimizer, images, labels)
avg_loss.update_state(loss)
if tf.equal(optimizer.iterations % log_freq, 0):
tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
avg_loss.reset_states()
def test(model, test_x, test_y, step_num):
# training=False is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
loss = loss_fn(model(test_x, training=False), test_y)
tf.summary.scalar('loss', loss, step=step_num)
train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')
with train_summary_writer.as_default():
train(model, optimizer, dataset)
with test_summary_writer.as_default():
test(model, test_x, test_y, optimizer.iterations)
Wizualizuj wygenerowane podsumowania, wskazując TensorBoard na katalog dziennika podsumowań:
tensorboard --logdir /tmp/summaries
Użyj interfejsu API tf.summary
, aby zapisać dane podsumowujące do wizualizacji w TensorBoard. Aby uzyskać więcej informacji, przeczytaj przewodnik tf.summary
.
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
# Update the metrics
loss_metric.update_state(total_loss)
accuracy_metric.update_state(labels, predictions)
for epoch in range(NUM_EPOCHS):
# Reset the metrics
loss_metric.reset_states()
accuracy_metric.reset_states()
for inputs, labels in train_data:
train_step(inputs, labels)
# Get the metric results
mean_loss=loss_metric.result()
mean_accuracy = accuracy_metric.result()
print('Epoch: ', epoch)
print(' loss: {:.3f}'.format(mean_loss))
print(' accuracy: {:.3f}'.format(mean_accuracy))
2021-12-08 17:15:19.339736: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 0 loss: 0.142 accuracy: 0.991 2021-12-08 17:15:19.781743: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 1 loss: 0.125 accuracy: 0.997 2021-12-08 17:15:20.219033: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 2 loss: 0.110 accuracy: 0.997 2021-12-08 17:15:20.598085: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead. Epoch: 3 loss: 0.099 accuracy: 0.997 Epoch: 4 loss: 0.085 accuracy: 1.000 2021-12-08 17:15:20.981787: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
Keras nazwy metryczne
Modele Keras są spójne pod względem obsługi nazw metryk. Gdy przekazujesz ciąg na liście metryk, ten właśnie ciąg jest używany jako name
metryki . Nazwy te są widoczne w obiekcie historii zwróconym przez model.fit
oraz w logach przekazanych do keras.callbacks
. jest ustawiony na ciąg znaków przekazany na liście metryk.
model.compile(
optimizer = tf.keras.optimizers.Adam(0.001),
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0963 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969 2021-12-08 17:15:21.942940: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])
Debugowanie
Użyj szybkiego wykonania, aby uruchomić kod krok po kroku, aby sprawdzić kształty, typy danych i wartości. Niektóre interfejsy API, takie jak tf.function
, tf.keras
itp., są zaprojektowane do korzystania z wykonywania wykresów w celu zapewnienia wydajności i przenośności. Podczas debugowania użyj tf.config.run_functions_eagerly(True)
, aby użyć szybkiego wykonywania wewnątrz tego kodu.
Na przykład:
@tf.function
def f(x):
if x > 0:
import pdb
pdb.set_trace()
x = x + 1
return x
tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
6 @tf.function
7 def f(x):
8 if x > 0:
9 import pdb
10 pdb.set_trace()
11 -> x = x + 1
12 return x
13
14 tf.config.run_functions_eagerly(True)
15 f(tf.constant(1))
[EOF]
Działa to również w modelach Keras i innych interfejsach API, które obsługują szybkie wykonanie:
class CustomModel(tf.keras.models.Model):
@tf.function
def call(self, input_data):
if tf.reduce_mean(input_data) > 0:
return input_data
else:
import pdb
pdb.set_trace()
return input_data // 2
tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
10 if tf.reduce_mean(input_data) > 0:
11 return input_data
12 else:
13 import pdb
14 pdb.set_trace()
15 -> return input_data // 2
16
17
18 tf.config.run_functions_eagerly(True)
19 model = CustomModel()
20 model(tf.constant([-2, -4]))
Uwagi:
Metody
tf.keras.Model
, takie jakfit
,tf.function
evaluate
maskąpredict
Używając
tf.keras.Model.compile
, ustawrun_eagerly = True
, aby wyłączyć logikęModel
przed zawijaniem wtf.function
.Użyj
tf.data.experimental.enable_debug_mode
, aby włączyć tryb debugowania dlatf.data
. Przeczytaj dokumentację interfejsu API , aby uzyskać więcej informacji.
Nie trzymaj tf.Tensors
w swoich obiektach
Te tensory mogą zostać utworzone w funkcji tf.function
lub w gorliwym kontekście, a te tensory zachowują się inaczej. Zawsze używaj tf.Tensor
s tylko dla wartości pośrednich.
Aby śledzić stan, użyj tf.Variable
s, ponieważ zawsze można ich używać z obu kontekstów. Przeczytaj przewodnik tf.Variable
, aby dowiedzieć się więcej.
Zasoby i dalsze czytanie
Przeczytaj przewodniki i samouczki TF2, aby dowiedzieć się więcej o korzystaniu z TF2.
Jeśli wcześniej używałeś TF1.x, zdecydowanie zaleca się migrację kodu do TF2. Przeczytaj przewodniki po migracji, aby dowiedzieć się więcej.