ดูบน TensorFlow.org | ทำงานใน Google Colab | ดูแหล่งที่มาบน GitHub | ดาวน์โหลดโน๊ตบุ๊ค |
วลี "การบันทึกโมเดล TensorFlow" โดยทั่วไปหมายถึงหนึ่งในสองสิ่ง:
- จุดตรวจ OR
- โมเดลที่บันทึกไว้
จุดตรวจจับค่าที่แน่นอนของพารามิเตอร์ทั้งหมด ( tf.Variable
อบเจ็กต์) ที่ใช้โดยโมเดล จุดตรวจไม่มีคำอธิบายใดๆ ของการคำนวณที่กำหนดโดยโมเดล และโดยทั่วไปแล้วจะมีประโยชน์ก็ต่อเมื่อมีซอร์สโค้ดที่จะใช้ค่าพารามิเตอร์ที่บันทึกไว้เท่านั้น
ในทางกลับกัน รูปแบบ SavedModel รวมคำอธิบายต่อเนื่องของการคำนวณที่กำหนดโดยโมเดล นอกเหนือจากค่าพารามิเตอร์ (จุดตรวจสอบ) โมเดลในรูปแบบนี้ไม่ขึ้นกับซอร์สโค้ดที่สร้างโมเดล ดังนั้นจึงเหมาะสำหรับการปรับใช้ผ่าน TensorFlow Serving, TensorFlow Lite, TensorFlow.js หรือโปรแกรมในภาษาการเขียนโปรแกรมอื่นๆ (C, C++, Java, Go, Rust, C# เป็นต้น TensorFlow API)
คู่มือนี้ครอบคลุม API สำหรับการเขียนและการอ่านจุดตรวจ
ติดตั้ง
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
กำลังบันทึกจาก API การฝึกอบรม tf.keras
ดูคู่มือ tf.keras
เกี่ยวกับการบันทึกและการกู้คืน
tf.keras.Model.save_weights
บันทึกจุดตรวจ TensorFlow
net.save_weights('easy_checkpoint')
เขียนด่าน
สถานะคงอยู่ของโมเดล TensorFlow ถูกเก็บไว้ใน tf.Variable
สิ่งเหล่านี้สามารถสร้างได้โดยตรง แต่มักจะสร้างผ่าน API ระดับสูง เช่น tf.keras.layers
หรือ tf.keras.Model
วิธีที่ง่ายที่สุดในการจัดการตัวแปรคือการแนบไปกับอ็อบเจ็กต์ Python แล้วอ้างอิงอ็อบเจ็กต์เหล่านั้น
คลาสย่อยของ tf.train.Checkpoint
, tf.keras.layers.Layer
และ tf.keras.Model
ติดตามตัวแปรที่กำหนดให้กับแอตทริบิวต์โดยอัตโนมัติ ตัวอย่างต่อไปนี้สร้างโมเดลเชิงเส้นอย่างง่าย จากนั้นจึงเขียนจุดตรวจสอบซึ่งมีค่าสำหรับตัวแปรทั้งหมดของโมเดล
คุณสามารถบันทึก model-checkpoint ได้อย่างง่ายดายด้วย Model.save_weights
จุดตรวจด้วยตนเอง
ติดตั้ง
เพื่อช่วยสาธิตคุณลักษณะทั้งหมดของ tf.train.Checkpoint
ให้กำหนดชุดข้อมูลของเล่นและขั้นตอนการเพิ่มประสิทธิภาพ:
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape() as tape:
output = net(example['x'])
loss = tf.reduce_mean(tf.abs(output - example['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
สร้างวัตถุจุดตรวจ
ใช้อ็อบเจ็กต์ tf.train.Checkpoint
เพื่อสร้างจุดตรวจด้วยตนเอง โดยที่ออบเจ็กต์ที่คุณต้องการตรวจสอบถูกตั้งค่าเป็นแอตทริบิวต์บนวัตถุ
tf.train.CheckpointManager
ยังมีประโยชน์สำหรับการจัดการจุดตรวจหลายจุด
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
ฝึกและจุดตรวจโมเดล
ลูปการฝึกต่อไปนี้จะสร้างอินสแตนซ์ของโมเดลและเครื่องมือเพิ่มประสิทธิภาพ จากนั้นรวบรวมเป็นอ็อบเจ็กต์ tf.train.Checkpoint
โดยเรียกขั้นตอนการฝึกวนเป็นชุดของข้อมูลแต่ละชุด และเขียนจุดตรวจสอบลงดิสก์เป็นระยะ
def train_and_checkpoint(net, manager):
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch. Saved checkpoint for step 10: ./tf_ckpts/ckpt-1 loss 31.27 Saved checkpoint for step 20: ./tf_ckpts/ckpt-2 loss 24.68 Saved checkpoint for step 30: ./tf_ckpts/ckpt-3 loss 18.12 Saved checkpoint for step 40: ./tf_ckpts/ckpt-4 loss 11.65 Saved checkpoint for step 50: ./tf_ckpts/ckpt-5 loss 5.39
ฟื้นฟูและฝึกต่อ
หลังจากรอบการฝึกครั้งแรก คุณสามารถส่งต่อโมเดลและผู้จัดการคนใหม่ได้ แต่รับการฝึกอบรมตรงจุดที่คุณค้างไว้:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5 Saved checkpoint for step 60: ./tf_ckpts/ckpt-6 loss 1.50 Saved checkpoint for step 70: ./tf_ckpts/ckpt-7 loss 1.27 Saved checkpoint for step 80: ./tf_ckpts/ckpt-8 loss 0.56 Saved checkpoint for step 90: ./tf_ckpts/ckpt-9 loss 0.70 Saved checkpoint for step 100: ./tf_ckpts/ckpt-10 loss 0.35
ออบเจ็กต์ tf.train.CheckpointManager
ลบจุดตรวจเก่า ด้านบนมีการกำหนดค่าให้เก็บเฉพาะจุดตรวจล่าสุดสามจุดเท่านั้น
print(manager.checkpoints) # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']
เส้นทางเหล่านี้ เช่น './tf_ckpts/ckpt-10'
ไม่ใช่ไฟล์บนดิสก์ แต่เป็นคำนำหน้าสำหรับไฟล์ index
และไฟล์ข้อมูลอย่างน้อยหนึ่งไฟล์ที่มีค่าตัวแปร คำนำหน้าเหล่านี้จัดกลุ่มไว้ด้วยกันในไฟล์ checkpoint
เดียว ( './tf_ckpts/checkpoint'
) โดยที่ CheckpointManager
จะบันทึกสถานะ
ls ./tf_ckpts
checkpoint ckpt-8.data-00000-of-00001 ckpt-9.index ckpt-10.data-00000-of-00001 ckpt-8.index ckpt-10.index ckpt-9.data-00000-of-00001
กำลังโหลดกลศาสตร์
TensorFlow จะจับคู่ตัวแปรกับค่าจุดตรวจสอบโดยการสำรวจกราฟที่มีชื่อขอบ โดยเริ่มจากวัตถุที่กำลังโหลด ชื่อขอบมักจะมาจากชื่อแอตทริบิวต์ในวัตถุ เช่น "l1"
ใน self.l1 = tf.keras.layers.Dense(5)
tf.train.Checkpoint
ใช้ชื่ออาร์กิวเมนต์ของคีย์เวิร์ด เช่นเดียวกับใน "step"
ใน tf.train.Checkpoint(step=...)
กราฟการพึ่งพาจากตัวอย่างด้านบนมีลักษณะดังนี้:
เครื่องมือเพิ่มประสิทธิภาพจะเป็นสีแดง ตัวแปรปกติจะเป็นสีน้ำเงิน และตัวแปรช่องเครื่องมือเพิ่มประสิทธิภาพจะเป็นสีส้ม โหนดอื่นๆ เช่น แทน tf.train.Checkpoint
เป็นสีดำ
ตัวแปรสล็อตเป็นส่วนหนึ่งของสถานะของตัวเพิ่มประสิทธิภาพ แต่ถูกสร้างขึ้นสำหรับตัวแปรเฉพาะ ตัวอย่างเช่น ขอบ 'm'
ด้านบนสอดคล้องกับโมเมนตัม ซึ่งเครื่องมือเพิ่มประสิทธิภาพ Adam ติดตามสำหรับแต่ละตัวแปร ตัวแปรสล็อตจะถูกบันทึกในจุดตรวจสอบหากทั้งตัวแปรและตัวเพิ่มประสิทธิภาพจะถูกบันทึก ดังนั้นขอบที่เป็นเส้นประ
การเรียกการ restore
บนอ็อบเจ็กต์ tf.train.Checkpoint
จะเข้าคิวการคืนค่าที่ร้องขอ เรียกคืนค่าตัวแปรทันทีที่มีพาธที่ตรงกันจากออบเจ็กต์ Checkpoint
ตัวอย่างเช่น คุณสามารถโหลดเฉพาะความเอนเอียงจากโมเดลที่คุณกำหนดไว้ด้านบนโดยสร้างเส้นทางเดียวไปยังโมเดลดังกล่าวผ่านเครือข่ายและเลเยอร์
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy()) # This gets the restored value.
[0. 0. 0. 0. 0.] [2.7209885 3.7588918 4.421351 4.1466427 4.0712557]
กราฟการพึ่งพาสำหรับออบเจ็กต์ใหม่เหล่านี้เป็นกราฟย่อยที่เล็กกว่ามากของจุดตรวจที่ใหญ่กว่าที่คุณเขียนไว้ด้านบน รวมเฉพาะอคติและตัวนับบันทึกที่ tf.train.Checkpoint
ใช้เพื่อนับจุดตรวจ
restore
ส่งคืนวัตถุสถานะซึ่งมีการยืนยันที่เป็นทางเลือก ออบเจ็กต์ทั้งหมดที่สร้างขึ้นใน Checkpoint
ใหม่ได้รับการฟื้นฟูแล้ว ดังนั้น status.assert_existing_objects_matched
ผ่าน
status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>
มีอ็อบเจ็กต์จำนวนมากในจุดตรวจที่ไม่ตรงกัน รวมถึงเคอร์เนลของเลเยอร์และตัวแปรของออปติไมเซอร์ status.assert_consumed
ผ่านก็ต่อเมื่อจุดตรวจและโปรแกรมตรงกันทุกประการ และจะโยนข้อยกเว้นที่นี่
การบูรณะที่รอการตัดบัญชี
Layer
อ็อบเจ็กต์ใน TensorFlow อาจเลื่อนการสร้างตัวแปรไปเป็นการโทรครั้งแรก เมื่อรูปร่างอินพุตพร้อมใช้งาน ตัวอย่างเช่น รูปร่างของเคอร์เนลของเลเยอร์ Dense
ขึ้นอยู่กับทั้งรูปร่างอินพุตและเอาต์พุตของเลเยอร์ ดังนั้นรูปร่างเอาต์พุตที่ต้องการในฐานะอาร์กิวเมนต์ตัวสร้างจึงไม่มีข้อมูลเพียงพอที่จะสร้างตัวแปรได้ด้วยตัวเอง เนื่องจากการเรียกใช้ Layer
ยังอ่านค่าของตัวแปรด้วย การคืนค่าจะต้องเกิดขึ้นระหว่างการสร้างตัวแปรและการใช้งานครั้งแรก
เพื่อรองรับสำนวนนี้ tf.train.Checkpoint
defers restores ซึ่งยังไม่มีตัวแปรที่ตรงกัน
deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy()) # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy()) # Restored
[[0. 0. 0. 0. 0.]] [[4.5854754 4.607731 4.649179 4.8474874 5.121 ]]ตัวยึดตำแหน่ง22
ตรวจสอบจุดตรวจด้วยตนเอง
tf.train.load_checkpoint
ส่งคืน CheckpointReader
ที่ให้การเข้าถึงเนื้อหาด่านที่ต่ำกว่า ประกอบด้วยการแมปจากคีย์ของตัวแปรแต่ละตัว ไปยังรูปร่างและ dtype ของตัวแปรแต่ละตัวในจุดตรวจสอบ คีย์ของตัวแปรคือเส้นทางของออบเจ็กต์ เช่นเดียวกับในกราฟที่แสดงด้านบน
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH', 'iterator/.ATTRIBUTES/ITERATOR_STATE', 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', 'save_counter/.ATTRIBUTES/VARIABLE_VALUE', 'step/.ATTRIBUTES/VARIABLE_VALUE']
ดังนั้น หากคุณสนใจในมูลค่าของ net.l1.kernel
คุณสามารถรับค่าโดยใช้รหัสต่อไปนี้:
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5] Dtype: float32
นอกจากนี้ยังมีวิธี get_tensor
ที่ให้คุณตรวจสอบค่าของตัวแปรได้:
reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121 ]], dtype=float32)
การติดตามวัตถุ
จุดตรวจสอบจะบันทึกและกู้คืนค่าของ tf.Variable
โดย "ติดตาม" ตัวแปรหรืออ็อบเจ็กต์ที่ติดตามได้ซึ่งตั้งค่าไว้ในแอตทริบิวต์ใดแอตทริบิวต์หนึ่ง เมื่อดำเนินการบันทึก ตัวแปรจะถูกรวบรวมแบบเรียกซ้ำจากออบเจ็กต์ที่ติดตามที่เข้าถึงได้ทั้งหมด
เช่นเดียวกับการกำหนดแอตทริบิวต์โดยตรง เช่น self.l1 = tf.keras.layers.Dense(5)
การกำหนดรายการและพจนานุกรมให้กับแอตทริบิวต์จะติดตามเนื้อหา
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')
restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
คุณอาจสังเกตเห็นออบเจ็กต์ของแรปเปอร์สำหรับรายการและพจนานุกรม Wrapper เหล่านี้เป็นเวอร์ชันที่ตรวจสอบได้ของโครงสร้างข้อมูลพื้นฐาน เช่นเดียวกับการโหลดตามแอตทริบิวต์ Wrapper เหล่านี้จะคืนค่าของตัวแปรทันทีที่เพิ่มลงในคอนเทนเนอร์
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1) # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])
ออบเจ็กต์ที่ติดตามได้รวมถึง tf.train.Checkpoint
, tf.Module
และคลาสย่อย (เช่น keras.layers.Layer
และ keras.Model
) และคอนเทนเนอร์ Python ที่รู้จัก:
-
dict
(และcollections.OrderedDict
) -
list
-
tuple
(และcollections.namedtuple
typing.NamedTuple
)
ไม่รองรับ คอนเทนเนอร์ประเภทอื่นๆ ซึ่งรวมถึง:
-
collections.defaultdict
-
set
วัตถุ Python อื่น ๆ ทั้งหมดจะ ถูกละเว้น รวมถึง:
-
int
-
string
-
float
สรุป
ออบเจ็กต์ TensorFlow ให้กลไกอัตโนมัติที่ง่ายดายสำหรับการบันทึกและกู้คืนค่าของตัวแปรที่ใช้