Punti di controllo della formazione

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza l'origine su GitHub Scarica quaderno

La frase "Salvataggio di un modello TensorFlow" in genere significa una delle due cose:

  1. Punti di controllo, O
  2. Modello salvato.

I checkpoint acquisiscono il valore esatto di tutti i parametri ( tf.Variable objects) utilizzati da un modello. I checkpoint non contengono alcuna descrizione del calcolo definito dal modello e quindi sono generalmente utili solo quando è disponibile il codice sorgente che utilizzerà i valori dei parametri salvati.

Il formato SavedModel invece include una descrizione serializzata del calcolo definito dal modello oltre ai valori dei parametri (checkpoint). I modelli in questo formato sono indipendenti dal codice sorgente che ha creato il modello. Sono quindi adatti per l'implementazione tramite TensorFlow Serving, TensorFlow Lite, TensorFlow.js o programmi in altri linguaggi di programmazione (le API di TensorFlow C, C++, Java, Go, Rust, C# ecc.).

Questa guida tratta le API per la scrittura e la lettura dei checkpoint.

Impostare

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Salvataggio dalle API di addestramento di tf.keras

Consulta la guida di tf.keras sul salvataggio e il ripristino.

tf.keras.Model.save_weights salva un checkpoint TensorFlow.

net.save_weights('easy_checkpoint')

Scrivere posti di blocco

Lo stato persistente di un modello TensorFlow è archiviato in oggetti tf.Variable . Questi possono essere costruiti direttamente, ma sono spesso creati tramite API di alto livello come tf.keras.layers o tf.keras.Model .

Il modo più semplice per gestire le variabili è collegarle agli oggetti Python, quindi fare riferimento a tali oggetti.

Le sottoclassi di tf.train.Checkpoint , tf.keras.layers.Layer e tf.keras.Model automaticamente le variabili assegnate ai loro attributi. L'esempio seguente costruisce un semplice modello lineare, quindi scrive checkpoint che contengono valori per tutte le variabili del modello.

Puoi facilmente salvare un checkpoint del modello con Model.save_weights .

Checkpoint manuale

Impostare

Per aiutare a dimostrare tutte le funzionalità di tf.train.Checkpoint , definisci un set di dati del giocattolo e una fase di ottimizzazione:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Crea gli oggetti checkpoint

Utilizzare un oggetto tf.train.Checkpoint per creare manualmente un checkpoint, in cui gli oggetti che si desidera controllare sono impostati come attributi sull'oggetto.

Un tf.train.CheckpointManager può anche essere utile per la gestione di più checkpoint.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Addestra e controlla il modello

Il ciclo di addestramento seguente crea un'istanza del modello e di un ottimizzatore, quindi li raccoglie in un oggetto tf.train.Checkpoint . Richiama la fase di addestramento in un ciclo su ogni batch di dati e scrive periodicamente i checkpoint su disco.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 31.27
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 24.68
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 18.12
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 11.65
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 5.39

Ripristina e continua l'allenamento

Dopo il primo ciclo di formazione puoi passare un nuovo modello e manager, ma riprendere la formazione esattamente da dove avevi interrotto:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.50
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.27
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.56
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.70
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.35

L'oggetto tf.train.CheckpointManager elimina i vecchi checkpoint. Sopra è configurato per mantenere solo i tre checkpoint più recenti.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Questi percorsi, ad esempio './tf_ckpts/ckpt-10' , non sono file su disco. Sono invece prefissi per un file di index e uno o più file di dati che contengono i valori delle variabili. Questi prefissi sono raggruppati in un unico file di checkpoint ( './tf_ckpts/checkpoint' ) in cui CheckpointManager salva il suo stato.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Meccanica di caricamento

TensorFlow abbina le variabili ai valori di checkpoint attraversando un grafico diretto con bordi denominati, a partire dall'oggetto caricato. I nomi dei bordi in genere derivano dai nomi degli attributi negli oggetti, ad esempio "l1" in self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint usa i nomi degli argomenti delle sue parole chiave, come in "step" in tf.train.Checkpoint(step=...) .

Il grafico delle dipendenze dall'esempio sopra è simile al seguente:

Visualizzazione del grafico delle dipendenze per il ciclo di formazione di esempio

L'ottimizzatore è in rosso, le variabili regolari sono in blu e le variabili dello slot dell'ottimizzatore sono in arancione. Gli altri nodi, ad esempio, che rappresentano il tf.train.Checkpoint , sono in nero.

Le variabili slot fanno parte dello stato dell'ottimizzatore, ma vengono create per una variabile specifica. Ad esempio, i bordi 'm' sopra corrispondono alla quantità di moto, che l'ottimizzatore Adam tiene traccia per ciascuna variabile. Le variabili slot vengono salvate in un checkpoint solo se la variabile e l'ottimizzatore vengono salvati entrambi, quindi i bordi tratteggiati.

La chiamata al restore su un oggetto tf.train.Checkpoint in coda i ripristini richiesti, ripristinando i valori delle variabili non appena esiste un percorso corrispondente dall'oggetto Checkpoint . Ad esempio, puoi caricare solo la distorsione dal modello che hai definito sopra ricostruendo un percorso ad esso attraverso la rete e il livello.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.7209885 3.7588918 4.421351  4.1466427 4.0712557]

Il grafico delle dipendenze per questi nuovi oggetti è un sottografo molto più piccolo del checkpoint più grande che hai scritto sopra. Include solo il bias e un contatore di salvataggio che tf.train.Checkpoint usa per numerare i checkpoint.

Visualizzazione di un sottografo per la variabile bias

restore restituisce un oggetto di stato, che ha asserzioni facoltative. Tutti gli oggetti creati nel nuovo Checkpoint sono stati ripristinati, quindi status.assert_existing_objects_matched passa.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>

Ci sono molti oggetti nel checkpoint che non corrispondono, incluso il kernel del livello e le variabili dell'ottimizzatore. status.assert_consumed passa solo se il checkpoint e il programma corrispondono esattamente e genererebbe un'eccezione qui.

Restauri differiti

Gli oggetti di Layer in TensorFlow possono rinviare la creazione di variabili alla loro prima chiamata, quando sono disponibili forme di input. Ad esempio, la forma del kernel di un livello Dense dipende sia dalle forme di input che da quelle di output del livello, quindi la forma di output richiesta come argomento del costruttore non è un'informazione sufficiente per creare la variabile da sola. Poiché la chiamata a un Layer legge anche il valore della variabile, è necessario eseguire un ripristino tra la creazione della variabile e il suo primo utilizzo.

Per supportare questo linguaggio, tf.train.Checkpoint rinvia i ripristini che non hanno ancora una variabile corrispondente.

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5854754 4.607731  4.649179  4.8474874 5.121    ]]

Ispezione manuale dei checkpoint

tf.train.load_checkpoint restituisce un CheckpointReader che fornisce un accesso di livello inferiore al contenuto del checkpoint. Contiene le mappature dalla chiave di ogni variabile, alla forma e al dtype per ogni variabile nel checkpoint. La chiave di una variabile è il percorso dell'oggetto, come nei grafici mostrati sopra.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Quindi, se sei interessato al valore di net.l1.kernel puoi ottenere il valore con il seguente codice:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Fornisce anche un metodo get_tensor che consente di ispezionare il valore di una variabile:

reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121    ]],
      dtype=float32)

Tracciamento degli oggetti

I checkpoint salvano e ripristinano i valori degli oggetti tf.Variable "tracciando" qualsiasi variabile o oggetto tracciabile impostato in uno dei suoi attributi. Quando si esegue un salvataggio, le variabili vengono raccolte in modo ricorsivo da tutti gli oggetti tracciati raggiungibili.

Come per le assegnazioni di attributi diretti come self.l1 = tf.keras.layers.Dense(5) , l'assegnazione di elenchi e dizionari agli attributi ne traccia il contenuto.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Potresti notare oggetti wrapper per elenchi e dizionari. Questi wrapper sono versioni controllabili delle strutture di dati sottostanti. Proprio come il caricamento basato sugli attributi, questi wrapper ripristinano il valore di una variabile non appena viene aggiunta al contenitore.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Gli oggetti tracciabili includono tf.train.Checkpoint , tf.Module e le sue sottoclassi (ad esempio keras.layers.Layer e keras.Model ) e contenitori Python riconosciuti:

  • dict (e collections.OrderedDict . OrderedDict )
  • list
  • tuple (e collections.namedtuple , typing.NamedTuple )

Non sono supportati altri tipi di contenitori, inclusi:

  • collections.defaultdict
  • set

Tutti gli altri oggetti Python vengono ignorati , inclusi:

  • int
  • string
  • float

Riepilogo

Gli oggetti TensorFlow forniscono un semplice meccanismo automatico per salvare e ripristinare i valori delle variabili che utilizzano.