Pontos de verificação de treinamento

Veja no TensorFlow.org Executar no Google Colab Ver fonte no GitHub Baixar caderno

A frase "Salvar um modelo do TensorFlow" normalmente significa uma das duas coisas:

  1. Pontos de verificação, OU
  2. SavedModel.

Os pontos de verificação capturam o valor exato de todos os parâmetros (objetos tf.Variable ) usados ​​por um modelo. Os pontos de verificação não contêm nenhuma descrição da computação definida pelo modelo e, portanto, normalmente são úteis apenas quando o código-fonte que usará os valores de parâmetro salvos estiver disponível.

O formato SavedModel, por outro lado, inclui uma descrição serializada da computação definida pelo modelo, além dos valores dos parâmetros (ponto de verificação). Modelos nesse formato são independentes do código-fonte que criou o modelo. Portanto, eles são adequados para implantação via TensorFlow Serving, TensorFlow Lite, TensorFlow.js ou programas em outras linguagens de programação (as APIs C, C++, Java, Go, Rust, C# etc. TensorFlow).

Este guia abrange APIs para escrever e ler pontos de verificação.

Configurar

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Salvando das APIs de treinamento tf.keras

Consulte o guia tf.keras sobre como salvar e restaurar arquivos .

tf.keras.Model.save_weights salva um ponto de verificação do TensorFlow.

net.save_weights('easy_checkpoint')

Escrevendo pontos de verificação

O estado persistente de um modelo do TensorFlow é armazenado em objetos tf.Variable . Eles podem ser construídos diretamente, mas geralmente são criados por meio de APIs de alto nível, como tf.keras.layers ou tf.keras.Model .

A maneira mais fácil de gerenciar variáveis ​​é anexá-las a objetos Python e fazer referência a esses objetos.

As subclasses de tf.train.Checkpoint , tf.keras.layers.Layer e tf.keras.Model rastreiam automaticamente as variáveis ​​atribuídas a seus atributos. O exemplo a seguir constrói um modelo linear simples e, em seguida, grava pontos de verificação que contêm valores para todas as variáveis ​​do modelo.

Você pode salvar facilmente um ponto de verificação de modelo com Model.save_weights .

Ponto de verificação manual

Configurar

Para ajudar a demonstrar todos os recursos de tf.train.Checkpoint , defina um conjunto de dados de brinquedo e uma etapa de otimização:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Criar os objetos de ponto de verificação

Use um objeto tf.train.Checkpoint para criar manualmente um ponto de verificação, onde os objetos que você deseja verificar são definidos como atributos no objeto.

Um tf.train.CheckpointManager também pode ser útil para gerenciar vários pontos de verificação.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Treine e verifique o modelo

O loop de treinamento a seguir cria uma instância do modelo e de um otimizador e os reúne em um objeto tf.train.Checkpoint . Ele chama a etapa de treinamento em um loop em cada lote de dados e grava periodicamente os pontos de verificação no disco.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 31.27
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 24.68
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 18.12
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 11.65
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 5.39

Restaurar e continuar o treinamento

Após o primeiro ciclo de treinamento, você pode passar por um novo modelo e gerente, mas continuar o treinamento exatamente de onde parou:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.50
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.27
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.56
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.70
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.35

O objeto tf.train.CheckpointManager exclui pontos de verificação antigos. Acima está configurado para manter apenas os três checkpoints mais recentes.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Esses caminhos, por exemplo, './tf_ckpts/ckpt-10' , não são arquivos em disco. Em vez disso, eles são prefixos para um arquivo de index e um ou mais arquivos de dados que contêm os valores das variáveis. Esses prefixos são agrupados em um único arquivo de checkpoint de verificação ( './tf_ckpts/checkpoint' ) onde o CheckpointManager salva seu estado.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mecânica de carregamento

O TensorFlow combina variáveis ​​com valores de checkpoint percorrendo um gráfico direcionado com arestas nomeadas, começando pelo objeto que está sendo carregado. Nomes de arestas normalmente vêm de nomes de atributos em objetos, por exemplo, o "l1" em self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint usa seus nomes de argumentos de palavra-chave, como no "step" em tf.train.Checkpoint(step=...) .

O gráfico de dependência do exemplo acima se parece com isso:

Visualização do gráfico de dependência para o loop de treinamento de exemplo

O otimizador está em vermelho, as variáveis ​​regulares estão em azul e as variáveis ​​de slot do otimizador estão em laranja. Os outros nós—por exemplo, representando o tf.train.Checkpoint —estão em preto.

As variáveis ​​de slot fazem parte do estado do otimizador, mas são criadas para uma variável específica. Por exemplo, as arestas 'm' acima correspondem ao momento, que o otimizador Adam rastreia para cada variável. As variáveis ​​de slot só são salvas em um ponto de verificação se a variável e o otimizador forem salvos, portanto, as bordas tracejadas.

Chamar restore em um objeto tf.train.Checkpoint enfileira as restaurações solicitadas, restaurando os valores das variáveis ​​assim que houver um caminho correspondente do objeto Checkpoint . Por exemplo, você pode carregar apenas o viés do modelo definido acima, reconstruindo um caminho para ele através da rede e da camada.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.7209885 3.7588918 4.421351  4.1466427 4.0712557]

O gráfico de dependência para esses novos objetos é um subgráfico muito menor do ponto de verificação maior que você escreveu acima. Ele inclui apenas o viés e um contador de salvamento que tf.train.Checkpoint usa para numerar os pontos de verificação.

Visualização de um subgráfico para a variável de viés

restore retorna um objeto de status, que possui asserções opcionais. Todos os objetos criados no novo Checkpoint foram restaurados, então status.assert_existing_objects_matched passa.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>

Existem muitos objetos no ponto de verificação que não correspondem, incluindo o kernel da camada e as variáveis ​​do otimizador. status.assert_consumed só passa se o ponto de verificação e o programa corresponderem exatamente e lançariam uma exceção aqui.

Restaurações adiadas

Objetos de Layer no TensorFlow podem adiar a criação de variáveis ​​para sua primeira chamada, quando as formas de entrada estiverem disponíveis. Por exemplo, a forma do kernel de uma camada Dense depende das formas de entrada e saída da camada e, portanto, a forma de saída necessária como argumento do construtor não é informação suficiente para criar a variável por conta própria. Como chamar uma Layer também lê o valor da variável, uma restauração deve ocorrer entre a criação da variável e seu primeiro uso.

Para suportar esse idioma, tf.train.Checkpoint adia restaurações que ainda não possuem uma variável correspondente.

deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5854754 4.607731  4.649179  4.8474874 5.121    ]]

Inspecionando manualmente os pontos de verificação

tf.train.load_checkpoint retorna um CheckpointReader que fornece acesso de nível inferior ao conteúdo do ponto de verificação. Ele contém mapeamentos da chave de cada variável para a forma e o dtype de cada variável no ponto de verificação. A chave de uma variável é o caminho do objeto, como nos gráficos exibidos acima.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Então, se você estiver interessado no valor de net.l1.kernel , você pode obter o valor com o seguinte código:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Ele também fornece um método get_tensor que permite inspecionar o valor de uma variável:

reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121    ]],
      dtype=float32)

Rastreamento de objetos

Os pontos de verificação salvam e restauram os valores dos objetos tf.Variable "rastreando" qualquer variável ou objeto rastreável definido em um de seus atributos. Ao executar um salvamento, as variáveis ​​são coletadas recursivamente de todos os objetos rastreados alcançáveis.

Assim como as atribuições diretas de atributos como self.l1 = tf.keras.layers.Dense(5) , atribuir listas e dicionários a atributos rastreará seu conteúdo.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Você pode observar objetos wrapper para listas e dicionários. Esses wrappers são versões passíveis de verificação das estruturas de dados subjacentes. Assim como o carregamento baseado em atributo, esses wrappers restauram o valor de uma variável assim que ela é adicionada ao contêiner.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Objetos rastreáveis ​​incluem tf.train.Checkpoint , tf.Module e suas subclasses (por exemplo, keras.layers.Layer e keras.Model ) e contêineres Python reconhecidos:

  • dict (e collections.OrderedDict )
  • list
  • tuple (e collections.namedtuple , typing.NamedTuple )

Outros tipos de contêiner não são compatíveis , incluindo:

  • collections.defaultdict
  • set

Todos os outros objetos Python são ignorados , incluindo:

  • int
  • string
  • float

Resumo

Os objetos do TensorFlow fornecem um mecanismo automático fácil para salvar e restaurar os valores das variáveis ​​que eles usam.