Zobacz na TensorFlow.org | Uruchom w Google Colab | Wyświetl źródło na GitHub | Pobierz notatnik |
Wyrażenie „Zapisywanie modelu TensorFlow” zazwyczaj oznacza jedną z dwóch rzeczy:
- Punkty kontrolne, OR
- Zapisany Model.
Punkty kontrolne przechwytują dokładną wartość wszystkich parametrów (obiekty tf.Variable
) używanych przez model. Punkty kontrolne nie zawierają żadnego opisu obliczeń zdefiniowanych przez model i dlatego są zwykle przydatne tylko wtedy, gdy dostępny jest kod źródłowy, który będzie używał zapisanych wartości parametrów.
Z drugiej strony format SavedModel zawiera zserializowany opis obliczeń zdefiniowanych przez model oprócz wartości parametrów (punkt kontrolny). Modele w tym formacie są niezależne od kodu źródłowego, który utworzył model. Dzięki temu nadają się do wdrożenia za pośrednictwem TensorFlow Serving, TensorFlow Lite, TensorFlow.js lub programów w innych językach programowania (C, C++, Java, Go, Rust, C# itp. API TensorFlow).
Ten przewodnik obejmuje interfejsy API do zapisywania i odczytywania punktów kontrolnych.
Ustawiać
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
Oszczędzanie z interfejsów API szkoleniowych tf.keras
Zobacz przewodnik tf.keras
dotyczący zapisywania i przywracania .
tf.keras.Model.save_weights
zapisuje punkt kontrolny TensorFlow.
net.save_weights('easy_checkpoint')
Pisanie punktów kontrolnych
Trwały stan modelu TensorFlow jest przechowywany w obiektach tf.Variable
. Można je konstruować bezpośrednio, ale często są one tworzone za pomocą interfejsów API wysokiego poziomu, takich jak tf.keras.layers
lub tf.keras.Model
.
Najłatwiejszym sposobem zarządzania zmiennymi jest dołączanie ich do obiektów Pythona, a następnie odwoływanie się do tych obiektów.
Podklasy tf.train.Checkpoint
, tf.keras.layers.Layer
i tf.keras.Model
automatycznie śledzą zmienne przypisane do ich atrybutów. Poniższy przykład konstruuje prosty model liniowy, a następnie zapisuje punkty kontrolne, które zawierają wartości dla wszystkich zmiennych modelu.
Możesz łatwo zapisać punkt kontrolny modelu za pomocą Model.save_weights
.
Ręczne punkty kontrolne
Ustawiać
Aby pomóc zademonstrować wszystkie funkcje tf.train.Checkpoint
, zdefiniuj zestaw danych zabawek i krok optymalizacji:
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape() as tape:
output = net(example['x'])
loss = tf.reduce_mean(tf.abs(output - example['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
Utwórz obiekty punktów kontrolnych
Użyj obiektu tf.train.Checkpoint
, aby ręcznie utworzyć punkt kontrolny, w którym obiekty, które chcesz sprawdzić, są ustawione jako atrybuty obiektu.
tf.train.CheckpointManager
może być również pomocny w zarządzaniu wieloma punktami kontrolnymi.
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
Trenuj i sprawdzaj model
Poniższa pętla szkoleniowa tworzy instancję modelu i optymalizatora, a następnie gromadzi je w obiekcie tf.train.Checkpoint
. Wywołuje etap uczenia w pętli na każdej partii danych i okresowo zapisuje punkty kontrolne na dysku.
def train_and_checkpoint(net, manager):
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch. Saved checkpoint for step 10: ./tf_ckpts/ckpt-1 loss 31.27 Saved checkpoint for step 20: ./tf_ckpts/ckpt-2 loss 24.68 Saved checkpoint for step 30: ./tf_ckpts/ckpt-3 loss 18.12 Saved checkpoint for step 40: ./tf_ckpts/ckpt-4 loss 11.65 Saved checkpoint for step 50: ./tf_ckpts/ckpt-5 loss 5.39
Przywróć i kontynuuj trening
Po pierwszym cyklu szkoleniowym możesz przekazać nowy model i menedżera, ale rozpocznij szkolenie dokładnie w miejscu, w którym je przerwałeś:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5 Saved checkpoint for step 60: ./tf_ckpts/ckpt-6 loss 1.50 Saved checkpoint for step 70: ./tf_ckpts/ckpt-7 loss 1.27 Saved checkpoint for step 80: ./tf_ckpts/ckpt-8 loss 0.56 Saved checkpoint for step 90: ./tf_ckpts/ckpt-9 loss 0.70 Saved checkpoint for step 100: ./tf_ckpts/ckpt-10 loss 0.35
Obiekt tf.train.CheckpointManager
usuwa stare punkty kontrolne. Powyżej jest skonfigurowany tak, aby zachować tylko trzy najnowsze punkty kontrolne.
print(manager.checkpoints) # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']
Te ścieżki, np './tf_ckpts/ckpt-10'
, nie są plikami na dysku. Zamiast tego są prefiksami dla pliku index
i jednego lub więcej plików danych, które zawierają wartości zmiennych. Te prefiksy są zgrupowane w pojedynczym pliku checkpoint
( './tf_ckpts/checkpoint'
), w którym CheckpointManager
zapisuje swój stan.
ls ./tf_ckpts
checkpoint ckpt-8.data-00000-of-00001 ckpt-9.index ckpt-10.data-00000-of-00001 ckpt-8.index ckpt-10.index ckpt-9.data-00000-of-00001
Mechanika ładowania
TensorFlow dopasowuje zmienne do wartości w punktach kontrolnych, przemierzając ukierunkowany graf z nazwanymi krawędziami, zaczynając od ładowanego obiektu. Nazwy krawędzi zazwyczaj pochodzą od nazw atrybutów w obiektach, na przykład "l1"
w self.l1 = tf.keras.layers.Dense(5)
. tf.train.Checkpoint
używa swoich nazw argumentów słów kluczowych, jak w "step"
w tf.train.Checkpoint(step=...)
.
Wykres zależności z powyższego przykładu wygląda tak:
Optymalizator jest w kolorze czerwonym, zwykłe zmienne w kolorze niebieskim, a zmienne w boksie optymalizatora w kolorze pomarańczowym. Pozostałe węzły — na przykład reprezentujące tf.train.Checkpoint
— są czarne.
Zmienne przedziałów są częścią stanu optymalizatora, ale są tworzone dla określonej zmiennej. Na przykład, krawędzie 'm'
powyżej odpowiadają pędowi, który optymalizator Adam śledzi dla każdej zmiennej. Zmienne szczelin są zapisywane w punkcie kontrolnym tylko wtedy, gdy zmienna i optymalizator zostałyby zapisane, a więc krawędzie przerywane.
Wywołanie restore
na obiekcie tf.train.Checkpoint
kolejkuje żądane przywracania, przywracając wartości zmiennych, gdy tylko pojawi się pasująca ścieżka z obiektu Checkpoint
. Na przykład można załadować tylko odchylenie z modelu zdefiniowanego powyżej, rekonstruując jedną ścieżkę do niego przez sieć i warstwę.
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy()) # This gets the restored value.
[0. 0. 0. 0. 0.] [2.7209885 3.7588918 4.421351 4.1466427 4.0712557]
Wykres zależności dla tych nowych obiektów jest znacznie mniejszym podgrafem większego punktu kontrolnego, który napisałeś powyżej. Zawiera tylko odchylenie i licznik zapisywania, których tf.train.Checkpoint
używa do numerowania punktów kontrolnych.
restore
zwraca obiekt statusu, który ma opcjonalne asercje. Wszystkie obiekty utworzone w nowym punkcie Checkpoint
zostały przywrócone, więc status.assert_existing_objects_matched
przechodzi.
status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>
W punkcie kontrolnym znajduje się wiele obiektów, które nie pasują, w tym jądro warstwy i zmienne optymalizatora. status.assert_consumed
przechodzi tylko wtedy, gdy punkt kontrolny i program dokładnie pasują i zgłosi tutaj wyjątek.
Przywrócenia odroczone
Obiekty Layer
w TensorFlow mogą odroczyć tworzenie zmiennych do ich pierwszego wywołania, gdy dostępne są kształty wejściowe. Na przykład kształt jądra warstwy Dense
zależy zarówno od kształtów wejściowych, jak i wyjściowych warstwy, dlatego kształt wyjściowy wymagany jako argument konstruktora nie jest wystarczającą informacją do samodzielnego utworzenia zmiennej. Ponieważ wywołanie Layer
również odczytuje wartość zmiennej, przywrócenie musi nastąpić między utworzeniem zmiennej a jej pierwszym użyciem.
Aby wesprzeć ten idiom, tf.train.Checkpoint
odracza przywracanie, które nie ma jeszcze pasującej zmiennej.
deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy()) # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy()) # Restored
[[0. 0. 0. 0. 0.]] [[4.5854754 4.607731 4.649179 4.8474874 5.121 ]]
Ręczne sprawdzanie punktów kontrolnych
tf.train.load_checkpoint
zwraca CheckpointReader
, który daje niższy poziom dostępu do zawartości punktu kontrolnego. Zawiera mapowania z klucza każdej zmiennej do kształtu i typu d dla każdej zmiennej w punkcie kontrolnym. Kluczem do zmiennej jest ścieżka jej obiektu, tak jak na powyższych wykresach.
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH', 'iterator/.ATTRIBUTES/ITERATOR_STATE', 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', 'save_counter/.ATTRIBUTES/VARIABLE_VALUE', 'step/.ATTRIBUTES/VARIABLE_VALUE']
Jeśli więc interesuje Cię wartość net.l1.kernel
, możesz uzyskać wartość za pomocą następującego kodu:
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5] Dtype: float32
Zapewnia również metodę get_tensor
pozwalającą na sprawdzenie wartości zmiennej:
reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121 ]], dtype=float32)
Śledzenie obiektów
Punkty kontrolne zapisują i przywracają wartości obiektów tf.Variable
poprzez „śledzenie” dowolnej zmiennej lub śledzonego obiektu ustawionego w jednym z jego atrybutów. Podczas wykonywania zapisu zmienne są zbierane rekursywnie ze wszystkich osiągalnych śledzonych obiektów.
Podobnie jak w przypadku bezpośredniego przypisania atrybutów, takiego jak self.l1 = tf.keras.layers.Dense(5)
, przypisanie list i słowników do atrybutów będzie śledzić ich zawartość.
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')
restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
Możesz zauważyć opakowujące obiekty dla list i słowników. Opakowania te są możliwymi do sprawdzenia wersjami podstawowych struktur danych. Podobnie jak ładowanie oparte na atrybutach, te opakowania przywracają wartość zmiennej zaraz po jej dodaniu do kontenera.
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1) # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])
Obiekty, które można śledzić obejmują tf.train.Checkpoint
, tf.Module
i ich podklasy (np keras.layers.Layer
i keras.Model
) oraz rozpoznane kontenery Pythona:
-
dict
(icollections.OrderedDict
) -
list
-
tuple
(icollections.namedtuple
,typing.NamedTuple
)
Inne typy kontenerów nie są obsługiwane , w tym:
-
collections.defaultdict
-
set
Wszystkie inne obiekty Pythona są ignorowane , w tym:
-
int
-
string
-
float
Streszczenie
Obiekty TensorFlow zapewniają łatwy automatyczny mechanizm zapisywania i przywracania wartości zmiennych, których używają.