عرض على TensorFlow.org | تشغيل في Google Colab | عرض المصدر على جيثب | تحميل دفتر |
عادةً ما تعني عبارة "حفظ نموذج TensorFlow" أحد أمرين:
- نقاط التفتيش ، أو
- نموذج.
تلتقط نقاط التحقق القيمة الدقيقة لجميع المعلمات (الكائنات tf.Variable
) المستخدمة بواسطة النموذج. لا تحتوي نقاط التحقق على أي وصف للحساب المحدد بواسطة النموذج ، وبالتالي فهي مفيدة فقط عندما يتوفر كود المصدر الذي سيستخدم قيم المعلمات المحفوظة.
من ناحية أخرى ، يتضمن تنسيق SavedModel وصفًا متسلسلًا للحساب المحدد بواسطة النموذج بالإضافة إلى قيم المعلمات (نقطة التحقق). النماذج في هذا التنسيق مستقلة عن التعليمات البرمجية المصدر التي أنشأت النموذج. وبالتالي فهي مناسبة للنشر عبر TensorFlow Serving أو TensorFlow Lite أو TensorFlow.js أو البرامج بلغات البرمجة الأخرى (C ، C ++ ، Java ، Go ، Rust ، C # إلخ. TensorFlow APIs).
يغطي هذا الدليل واجهات برمجة التطبيقات (APIs) لكتابة وقراءة نقاط التفتيش.
يثبت
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
الحفظ من واجهات برمجة تطبيقات التدريب tf.keras
راجع دليل tf.keras
حول الحفظ والاستعادة.
يحفظ tf.keras.Model.save_weights
نقطة تفتيش TensorFlow.
net.save_weights('easy_checkpoint')
كتابة نقاط التفتيش
يتم تخزين الحالة المستمرة لنموذج tf.Variable
في كائنات متغيرة tf. يمكن إنشاء هذه بشكل مباشر ، ولكن غالبًا ما يتم إنشاؤها من خلال واجهات برمجة تطبيقات عالية المستوى مثل tf.keras.layers
أو tf.keras.Model
.
أسهل طريقة لإدارة المتغيرات هي إرفاقها بكائنات بايثون ، ثم الرجوع إلى تلك الكائنات.
تتتبع الفئات الفرعية لـ tf.train.Checkpoint
و tf.keras.layers.Layer
و tf.keras.Model المتغيرات المخصصة tf.keras.Model
تلقائيًا. يُنشئ المثال التالي نموذجًا خطيًا بسيطًا ، ثم يكتب نقاط التحقق التي تحتوي على قيم لجميع متغيرات النموذج.
يمكنك بسهولة حفظ نقطة فحص النموذج باستخدام Model.save_weights
.
التفتيش اليدوي
يثبت
للمساعدة في توضيح جميع ميزات tf.train.Checkpoint
، حدد مجموعة بيانات اللعبة وخطوة التحسين:
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape() as tape:
output = net(example['x'])
loss = tf.reduce_mean(tf.abs(output - example['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
إنشاء كائنات نقطة التفتيش
استخدم كائن tf.train.Checkpoint
لإنشاء نقطة تحقق يدويًا ، حيث يتم تعيين الكائنات التي تريد تحديدها كسمات على الكائن.
يمكن أن يكون tf.train.CheckpointManager
أيضًا في إدارة نقاط التفتيش المتعددة.
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
تدريب وفحص النموذج
تُنشئ حلقة التدريب التالية مثيلاً للنموذج والمحسِّن ، ثم تجمعهما في كائن tf.train.Checkpoint
. يستدعي خطوة التدريب في حلقة على كل دفعة من البيانات ، ويكتب بشكل دوري نقاط التفتيش على القرص.
def train_and_checkpoint(net, manager):
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch. Saved checkpoint for step 10: ./tf_ckpts/ckpt-1 loss 31.27 Saved checkpoint for step 20: ./tf_ckpts/ckpt-2 loss 24.68 Saved checkpoint for step 30: ./tf_ckpts/ckpt-3 loss 18.12 Saved checkpoint for step 40: ./tf_ckpts/ckpt-4 loss 11.65 Saved checkpoint for step 50: ./tf_ckpts/ckpt-5 loss 5.39
استعادة ومواصلة التدريب
بعد الدورة التدريبية الأولى ، يمكنك اجتياز نموذج ومدير جديدين ، ولكن يمكنك متابعة التدريب من حيث توقفت تمامًا:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5 Saved checkpoint for step 60: ./tf_ckpts/ckpt-6 loss 1.50 Saved checkpoint for step 70: ./tf_ckpts/ckpt-7 loss 1.27 Saved checkpoint for step 80: ./tf_ckpts/ckpt-8 loss 0.56 Saved checkpoint for step 90: ./tf_ckpts/ckpt-9 loss 0.70 Saved checkpoint for step 100: ./tf_ckpts/ckpt-10 loss 0.35
يحذف كائن tf.train.CheckpointManager
نقاط التحقق القديمة. تم تكوينه أعلاه للاحتفاظ بنقاط التفتيش الثلاثة الأخيرة فقط.
print(manager.checkpoints) # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']
هذه المسارات ، مثل './tf_ckpts/ckpt-10'
، ليست ملفات على القرص. بدلاً من ذلك ، فهي بادئات لملف index
وملف بيانات واحد أو أكثر يحتوي على القيم المتغيرة. يتم تجميع هذه البادئات معًا في ملف checkpoint
تحقق واحد ( './tf_ckpts/checkpoint'
) حيث يحفظ CheckpointManager
حالته.
ls ./tf_ckpts
checkpoint ckpt-8.data-00000-of-00001 ckpt-9.index ckpt-10.data-00000-of-00001 ckpt-8.index ckpt-10.index ckpt-9.data-00000-of-00001
ميكانيكا التحميل
يطابق TensorFlow المتغيرات مع القيم المحددة من خلال اجتياز الرسم البياني الموجه ذي الحواف المسماة ، بدءًا من الكائن الذي يتم تحميله. تأتي أسماء الحواف عادةً من أسماء السمات في الكائنات ، على سبيل المثال "l1"
في self.l1 = tf.keras.layers.Dense(5)
. يستخدم tf.train.Checkpoint
أسماء وسيطات الكلمات الرئيسية الخاصة به ، كما في "step"
في tf.train.Checkpoint(step=...)
.
يبدو الرسم البياني للتبعية من المثال أعلاه كما يلي:
المحسن باللون الأحمر ، والمتغيرات العادية باللون الأزرق ، ومتغيرات فتحة المحسن باللون البرتقالي. العقد الأخرى - على سبيل المثال ، التي تمثل tf.train.Checkpoint
- هي باللون الأسود.
تعد متغيرات الفتحة جزءًا من حالة المحسن ، ولكن يتم إنشاؤها لمتغير معين. على سبيل المثال ، تتوافق حواف 'm'
أعلاه مع الزخم ، الذي يتتبعه مُحسِّن آدم لكل متغير. يتم حفظ متغيرات الفتحة في نقطة فحص فقط إذا تم حفظ المتغير والمحسن ، وبالتالي الحواف المتقطعة.
يؤدي استدعاء restore
على tf.train.Checkpoint
الكائن في قائمة انتظار عمليات الاستعادة المطلوبة ، واستعادة القيم المتغيرة بمجرد وجود مسار مطابق من كائن Checkpoint
. على سبيل المثال ، يمكنك تحميل التحيز فقط من النموذج الذي حددته أعلاه عن طريق إعادة بناء مسار واحد إليه عبر الشبكة والطبقة.
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy()) # This gets the restored value.
[0. 0. 0. 0. 0.] [2.7209885 3.7588918 4.421351 4.1466427 4.0712557]
الرسم البياني للتبعية لهذه الكائنات الجديدة هو رسم بياني فرعي أصغر بكثير لنقطة التحقق الأكبر التي كتبتها أعلاه. يتضمن فقط التحيز وعداد الحفظ الذي يستخدمه tf.train.Checkpoint
نقاط التفتيش.
restore
كائن الحالة ، الذي يحتوي على تأكيدات اختيارية. تمت استعادة جميع الكائنات التي تم إنشاؤها في Checkpoint
التحقق الجديدة ، لذا فإن status.assert_existing_objects_matched
يمر.
status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>
هناك العديد من الكائنات في نقطة التحقق غير المتطابقة ، بما في ذلك نواة الطبقة ومتغيرات المحسن. status.assert_consumed
يمر فقط إذا كانت نقطة التفتيش والبرنامج متطابقتين تمامًا ، وسوف يطرح استثناء هنا.
الترميمات المؤجلة
قد تؤجل كائنات Layer
في TensorFlow إنشاء المتغيرات لاستدعائها الأول ، عندما تكون أشكال الإدخال متاحة. على سبيل المثال ، يعتمد شكل نواة الطبقة Dense
على كلٍ من أشكال الإدخال والإخراج للطبقة ، وبالتالي فإن شكل الإخراج المطلوب كوسيطة مُنشئ ليس معلومات كافية لإنشاء المتغير بمفرده. نظرًا لأن استدعاء Layer
يقرأ أيضًا قيمة المتغير ، يجب أن تحدث استعادة بين إنشاء المتغير واستخدامه لأول مرة.
لدعم هذا المصطلح ، يستعيد tf.train.Checkpoint
التي لا تحتوي بعد على متغير مطابق.
deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy()) # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy()) # Restored
[[0. 0. 0. 0. 0.]] [[4.5854754 4.607731 4.649179 4.8474874 5.121 ]]
التفتيش اليدوي على نقاط التفتيش
tf.train.load_checkpoint
بإرجاع CheckpointReader
الذي يوفر وصولاً منخفض المستوى لمحتويات نقطة التفتيش. يحتوي على تعيينات من مفتاح كل متغير ، إلى الشكل والنوع لكل متغير في نقطة التحقق. مفتاح المتغير هو مسار الكائن الخاص به ، كما هو الحال في الرسوم البيانية المعروضة أعلاه.
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH', 'iterator/.ATTRIBUTES/ITERATOR_STATE', 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', 'save_counter/.ATTRIBUTES/VARIABLE_VALUE', 'step/.ATTRIBUTES/VARIABLE_VALUE']
لذلك إذا كنت مهتمًا بقيمة net.l1.kernel
، يمكنك الحصول على القيمة باستخدام الكود التالي:
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5] Dtype: float32
يوفر أيضًا طريقة get_tensor
تسمح لك بفحص قيمة المتغير:
reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121 ]], dtype=float32)
تتبع الكائن
تقوم نقاط التحقق بحفظ واستعادة قيم tf.Variable
الكائنات عن طريق "تتبع" أي متغير أو كائن قابل للتتبع تم تعيينه في إحدى سماته. عند تنفيذ حفظ ، يتم جمع المتغيرات بشكل متكرر من جميع الكائنات المتعقبة التي يمكن الوصول إليها.
كما هو الحال مع تعيينات السمات المباشرة مثل self.l1 = tf.keras.layers.Dense(5)
، فإن تخصيص القوائم والقواميس للسمات سيؤدي إلى تتبع محتوياتها.
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')
restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
قد تلاحظ كائنات مجمعة للقوائم والقواميس. هذه الأغلفة هي إصدارات يمكن التحقق منها من هياكل البيانات الأساسية. تمامًا مثل التحميل المستند إلى السمة ، تستعيد هذه الأغلفة قيمة المتغير بمجرد إضافته إلى الحاوية.
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1) # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])
تتضمن الكائنات tf.train.Checkpoint
و tf.Module
الفرعية (مثل keras.layers.Layer
و keras.Model
) وحاويات Python المعترف بها:
-
dict
(collections.OrderedDict
. -
list
-
tuple
(collections.namedtuple
. التي تسمى tuple ، الكتابة.typing.NamedTuple
)
أنواع الحاويات الأخرى غير مدعومة ، بما في ذلك:
-
collections.defaultdict
-
set
يتم تجاهل جميع كائنات Python الأخرى ، بما في ذلك:
-
int
-
string
-
float
ملخص
توفر كائنات TensorFlow آلية تلقائية سهلة لحفظ واستعادة قيم المتغيرات التي تستخدمها.