Veja no TensorFlow.org | Executar no Google Colab | Ver fonte no GitHub | Baixar caderno |
A frase "Salvar um modelo do TensorFlow" normalmente significa uma das duas coisas:
- Pontos de verificação, OU
- SavedModel.
Os pontos de verificação capturam o valor exato de todos os parâmetros (objetos tf.Variable
) usados por um modelo. Os pontos de verificação não contêm nenhuma descrição da computação definida pelo modelo e, portanto, normalmente são úteis apenas quando o código-fonte que usará os valores de parâmetro salvos estiver disponível.
O formato SavedModel, por outro lado, inclui uma descrição serializada da computação definida pelo modelo, além dos valores dos parâmetros (ponto de verificação). Modelos nesse formato são independentes do código-fonte que criou o modelo. Portanto, eles são adequados para implantação via TensorFlow Serving, TensorFlow Lite, TensorFlow.js ou programas em outras linguagens de programação (as APIs C, C++, Java, Go, Rust, C# etc. TensorFlow).
Este guia abrange APIs para escrever e ler pontos de verificação.
Configurar
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
Salvando das APIs de treinamento tf.keras
Consulte o guia tf.keras
sobre como salvar e restaurar arquivos .
tf.keras.Model.save_weights
salva um ponto de verificação do TensorFlow.
net.save_weights('easy_checkpoint')
Escrevendo pontos de verificação
O estado persistente de um modelo do TensorFlow é armazenado em objetos tf.Variable
. Eles podem ser construídos diretamente, mas geralmente são criados por meio de APIs de alto nível, como tf.keras.layers
ou tf.keras.Model
.
A maneira mais fácil de gerenciar variáveis é anexá-las a objetos Python e fazer referência a esses objetos.
As subclasses de tf.train.Checkpoint
, tf.keras.layers.Layer
e tf.keras.Model
rastreiam automaticamente as variáveis atribuídas a seus atributos. O exemplo a seguir constrói um modelo linear simples e, em seguida, grava pontos de verificação que contêm valores para todas as variáveis do modelo.
Você pode salvar facilmente um ponto de verificação de modelo com Model.save_weights
.
Ponto de verificação manual
Configurar
Para ajudar a demonstrar todos os recursos de tf.train.Checkpoint
, defina um conjunto de dados de brinquedo e uma etapa de otimização:
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape() as tape:
output = net(example['x'])
loss = tf.reduce_mean(tf.abs(output - example['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
Criar os objetos de ponto de verificação
Use um objeto tf.train.Checkpoint
para criar manualmente um ponto de verificação, onde os objetos que você deseja verificar são definidos como atributos no objeto.
Um tf.train.CheckpointManager
também pode ser útil para gerenciar vários pontos de verificação.
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
Treine e verifique o modelo
O loop de treinamento a seguir cria uma instância do modelo e de um otimizador e os reúne em um objeto tf.train.Checkpoint
. Ele chama a etapa de treinamento em um loop em cada lote de dados e grava periodicamente os pontos de verificação no disco.
def train_and_checkpoint(net, manager):
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch. Saved checkpoint for step 10: ./tf_ckpts/ckpt-1 loss 31.27 Saved checkpoint for step 20: ./tf_ckpts/ckpt-2 loss 24.68 Saved checkpoint for step 30: ./tf_ckpts/ckpt-3 loss 18.12 Saved checkpoint for step 40: ./tf_ckpts/ckpt-4 loss 11.65 Saved checkpoint for step 50: ./tf_ckpts/ckpt-5 loss 5.39
Restaurar e continuar o treinamento
Após o primeiro ciclo de treinamento, você pode passar por um novo modelo e gerente, mas continuar o treinamento exatamente de onde parou:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5 Saved checkpoint for step 60: ./tf_ckpts/ckpt-6 loss 1.50 Saved checkpoint for step 70: ./tf_ckpts/ckpt-7 loss 1.27 Saved checkpoint for step 80: ./tf_ckpts/ckpt-8 loss 0.56 Saved checkpoint for step 90: ./tf_ckpts/ckpt-9 loss 0.70 Saved checkpoint for step 100: ./tf_ckpts/ckpt-10 loss 0.35
O objeto tf.train.CheckpointManager
exclui pontos de verificação antigos. Acima está configurado para manter apenas os três checkpoints mais recentes.
print(manager.checkpoints) # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']
Esses caminhos, por exemplo, './tf_ckpts/ckpt-10'
, não são arquivos em disco. Em vez disso, eles são prefixos para um arquivo de index
e um ou mais arquivos de dados que contêm os valores das variáveis. Esses prefixos são agrupados em um único arquivo de checkpoint
de verificação ( './tf_ckpts/checkpoint'
) onde o CheckpointManager
salva seu estado.
ls ./tf_ckpts
checkpoint ckpt-8.data-00000-of-00001 ckpt-9.index ckpt-10.data-00000-of-00001 ckpt-8.index ckpt-10.index ckpt-9.data-00000-of-00001
Mecânica de carregamento
O TensorFlow combina variáveis com valores de checkpoint percorrendo um gráfico direcionado com arestas nomeadas, começando pelo objeto que está sendo carregado. Nomes de arestas normalmente vêm de nomes de atributos em objetos, por exemplo, o "l1"
em self.l1 = tf.keras.layers.Dense(5)
. tf.train.Checkpoint
usa seus nomes de argumentos de palavra-chave, como no "step"
em tf.train.Checkpoint(step=...)
.
O gráfico de dependência do exemplo acima se parece com isso:
O otimizador está em vermelho, as variáveis regulares estão em azul e as variáveis de slot do otimizador estão em laranja. Os outros nós—por exemplo, representando o tf.train.Checkpoint
—estão em preto.
As variáveis de slot fazem parte do estado do otimizador, mas são criadas para uma variável específica. Por exemplo, as arestas 'm'
acima correspondem ao momento, que o otimizador Adam rastreia para cada variável. As variáveis de slot só são salvas em um ponto de verificação se a variável e o otimizador forem salvos, portanto, as bordas tracejadas.
Chamar restore
em um objeto tf.train.Checkpoint
enfileira as restaurações solicitadas, restaurando os valores das variáveis assim que houver um caminho correspondente do objeto Checkpoint
. Por exemplo, você pode carregar apenas o viés do modelo definido acima, reconstruindo um caminho para ele através da rede e da camada.
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy()) # This gets the restored value.
[0. 0. 0. 0. 0.] [2.7209885 3.7588918 4.421351 4.1466427 4.0712557]
O gráfico de dependência para esses novos objetos é um subgráfico muito menor do ponto de verificação maior que você escreveu acima. Ele inclui apenas o viés e um contador de salvamento que tf.train.Checkpoint
usa para numerar os pontos de verificação.
restore
retorna um objeto de status, que possui asserções opcionais. Todos os objetos criados no novo Checkpoint
foram restaurados, então status.assert_existing_objects_matched
passa.
status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>
Existem muitos objetos no ponto de verificação que não correspondem, incluindo o kernel da camada e as variáveis do otimizador. status.assert_consumed
só passa se o ponto de verificação e o programa corresponderem exatamente e lançariam uma exceção aqui.
Restaurações adiadas
Objetos de Layer
no TensorFlow podem adiar a criação de variáveis para sua primeira chamada, quando as formas de entrada estiverem disponíveis. Por exemplo, a forma do kernel de uma camada Dense
depende das formas de entrada e saída da camada e, portanto, a forma de saída necessária como argumento do construtor não é informação suficiente para criar a variável por conta própria. Como chamar uma Layer
também lê o valor da variável, uma restauração deve ocorrer entre a criação da variável e seu primeiro uso.
Para suportar esse idioma, tf.train.Checkpoint
adia restaurações que ainda não possuem uma variável correspondente.
deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy()) # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy()) # Restored
[[0. 0. 0. 0. 0.]] [[4.5854754 4.607731 4.649179 4.8474874 5.121 ]]
Inspecionando manualmente os pontos de verificação
tf.train.load_checkpoint
retorna um CheckpointReader
que fornece acesso de nível inferior ao conteúdo do ponto de verificação. Ele contém mapeamentos da chave de cada variável para a forma e o dtype de cada variável no ponto de verificação. A chave de uma variável é o caminho do objeto, como nos gráficos exibidos acima.
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH', 'iterator/.ATTRIBUTES/ITERATOR_STATE', 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', 'save_counter/.ATTRIBUTES/VARIABLE_VALUE', 'step/.ATTRIBUTES/VARIABLE_VALUE']
Então, se você estiver interessado no valor de net.l1.kernel
, você pode obter o valor com o seguinte código:
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5] Dtype: float32
Ele também fornece um método get_tensor
que permite inspecionar o valor de uma variável:
reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121 ]], dtype=float32)
Rastreamento de objetos
Os pontos de verificação salvam e restauram os valores dos objetos tf.Variable
"rastreando" qualquer variável ou objeto rastreável definido em um de seus atributos. Ao executar um salvamento, as variáveis são coletadas recursivamente de todos os objetos rastreados alcançáveis.
Assim como as atribuições diretas de atributos como self.l1 = tf.keras.layers.Dense(5)
, atribuir listas e dicionários a atributos rastreará seu conteúdo.
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')
restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
Você pode observar objetos wrapper para listas e dicionários. Esses wrappers são versões passíveis de verificação das estruturas de dados subjacentes. Assim como o carregamento baseado em atributo, esses wrappers restauram o valor de uma variável assim que ela é adicionada ao contêiner.
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1) # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])
Objetos rastreáveis incluem tf.train.Checkpoint
, tf.Module
e suas subclasses (por exemplo, keras.layers.Layer
e keras.Model
) e contêineres Python reconhecidos:
-
dict
(ecollections.OrderedDict
) -
list
-
tuple
(ecollections.namedtuple
,typing.NamedTuple
)
Outros tipos de contêiner não são compatíveis , incluindo:
-
collections.defaultdict
-
set
Todos os outros objetos Python são ignorados , incluindo:
-
int
-
string
-
float
Resumo
Os objetos do TensorFlow fornecem um mecanismo automático fácil para salvar e restaurar os valores das variáveis que eles usam.