ดูบน TensorFlow.org | ทำงานใน Google Colab | ดูแหล่งที่มาบน GitHub | ดาวน์โหลดโน๊ตบุ๊ค |
สามารถบันทึกความคืบหน้าของแบบจำลองได้ในระหว่างและหลังการฝึก ซึ่งหมายความว่าโมเดลสามารถกลับมาทำงานต่อจากที่ค้างไว้และหลีกเลี่ยงการฝึกอบรมที่ยาวนาน การบันทึกยังหมายความว่าคุณสามารถแบ่งปันแบบจำลองของคุณและคนอื่นๆ สามารถสร้างงานของคุณขึ้นมาใหม่ได้ เมื่อเผยแพร่แบบจำลองและเทคนิคการวิจัย ผู้ปฏิบัติงานการเรียนรู้ของเครื่องส่วนใหญ่จะแบ่งปัน:
- รหัสเพื่อสร้างแบบจำลองและ
- ตุ้มน้ำหนักที่ฝึกหรือพารามิเตอร์สำหรับรุ่น
การแบ่งปันข้อมูลนี้จะช่วยให้ผู้อื่นเข้าใจวิธีการทำงานของแบบจำลองและลองใช้ข้อมูลใหม่ด้วยตนเอง
ตัวเลือก
มีหลายวิธีในการบันทึกโมเดล TensorFlow ขึ้นอยู่กับ API ที่คุณใช้ คู่มือนี้ใช้ tf.keras ซึ่งเป็น API ระดับสูงเพื่อสร้างและฝึกโมเดลใน TensorFlow สำหรับวิธีการอื่นๆ โปรดดูคู่มือ TensorFlow Save and Restore หรือ Saving inความกระตือรือร้น
ติดตั้ง
ติดตั้งและนำเข้า
ติดตั้งและนำเข้า TensorFlow และการพึ่งพา:
pip install pyyaml h5py # Required to save models in HDF5 format
import os
import tensorflow as tf
from tensorflow import keras
print(tf.version.VERSION)
2.8.0-rc1
รับตัวอย่างชุดข้อมูล
ในการสาธิตวิธีบันทึกและโหลดตุ้มน้ำหนัก คุณจะต้องใช้ ชุดข้อมูล MNIST หากต้องการเร่งความเร็วการวิ่งเหล่านี้ ให้ใช้ 1,000 ตัวอย่างแรก:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
กำหนดรูปแบบ
เริ่มต้นด้วยการสร้างแบบจำลองลำดับอย่างง่าย:
# Define a simple sequential model
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
# Create a basic model instance
model = create_model()
# Display the model's architecture
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 512) 401920 dropout (Dropout) (None, 512) 0 dense_1 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________
บันทึกจุดตรวจระหว่างการฝึก
คุณสามารถใช้แบบจำลองที่ได้รับการฝึกอบรมโดยไม่ต้องฝึกใหม่ หรือรับการฝึกอบรมที่ค้างไว้ในกรณีที่กระบวนการฝึกอบรมถูกขัดจังหวะ การเรียกกลับ tf.keras.callbacks.ModelCheckpoint
ช่วยให้คุณบันทึกโมเดลได้อย่างต่อเนื่องทั้ง ในระหว่าง และเมื่อ สิ้นสุด การฝึก
การใช้งานโทรกลับจุดตรวจ
สร้างการเรียกกลับ tf.keras.callbacks.ModelCheckpoint
ที่ช่วยประหยัดน้ำหนักระหว่างการฝึกเท่านั้น:
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images, test_labels),
callbacks=[cp_callback]) # Pass callback to training
# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10 23/32 [====================>.........] - ETA: 0s - loss: 1.3666 - sparse_categorical_accuracy: 0.6060 Epoch 1: saving model to training_1/cp.ckpt 32/32 [==============================] - 1s 10ms/step - loss: 1.1735 - sparse_categorical_accuracy: 0.6690 - val_loss: 0.7180 - val_sparse_categorical_accuracy: 0.7750 Epoch 2/10 24/32 [=====================>........] - ETA: 0s - loss: 0.4238 - sparse_categorical_accuracy: 0.8789 Epoch 2: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.4201 - sparse_categorical_accuracy: 0.8810 - val_loss: 0.5621 - val_sparse_categorical_accuracy: 0.8150 Epoch 3/10 24/32 [=====================>........] - ETA: 0s - loss: 0.2795 - sparse_categorical_accuracy: 0.9336 Epoch 3: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.2815 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.4790 - val_sparse_categorical_accuracy: 0.8430 Epoch 4/10 24/32 [=====================>........] - ETA: 0s - loss: 0.2027 - sparse_categorical_accuracy: 0.9427 Epoch 4: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.2016 - sparse_categorical_accuracy: 0.9440 - val_loss: 0.4361 - val_sparse_categorical_accuracy: 0.8610 Epoch 5/10 24/32 [=====================>........] - ETA: 0s - loss: 0.1739 - sparse_categorical_accuracy: 0.9583 Epoch 5: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.1683 - sparse_categorical_accuracy: 0.9610 - val_loss: 0.4640 - val_sparse_categorical_accuracy: 0.8580 Epoch 6/10 23/32 [====================>.........] - ETA: 0s - loss: 0.1116 - sparse_categorical_accuracy: 0.9796 Epoch 6: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.1125 - sparse_categorical_accuracy: 0.9780 - val_loss: 0.4420 - val_sparse_categorical_accuracy: 0.8580 Epoch 7/10 24/32 [=====================>........] - ETA: 0s - loss: 0.0978 - sparse_categorical_accuracy: 0.9831 Epoch 7: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0989 - sparse_categorical_accuracy: 0.9820 - val_loss: 0.4163 - val_sparse_categorical_accuracy: 0.8590 Epoch 8/10 21/32 [==================>...........] - ETA: 0s - loss: 0.0669 - sparse_categorical_accuracy: 0.9911 Epoch 8: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 6ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.4411 - val_sparse_categorical_accuracy: 0.8600 Epoch 9/10 22/32 [===================>..........] - ETA: 0s - loss: 0.0495 - sparse_categorical_accuracy: 0.9972 Epoch 9: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0516 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.4064 - val_sparse_categorical_accuracy: 0.8650 Epoch 10/10 24/32 [=====================>........] - ETA: 0s - loss: 0.0436 - sparse_categorical_accuracy: 0.9948 Epoch 10: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0437 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.4061 - val_sparse_categorical_accuracy: 0.8770 <keras.callbacks.History at 0x7eff8d865390>
สิ่งนี้จะสร้างคอลเล็กชันไฟล์จุดตรวจสอบ TensorFlow หนึ่งชุดที่อัปเดตเมื่อสิ้นสุดแต่ละยุค:
os.listdir(checkpoint_dir)
['checkpoint', 'cp.ckpt.index', 'cp.ckpt.data-00000-of-00001']
ตราบใดที่ทั้งสองรุ่นมีสถาปัตยกรรมเดียวกัน คุณก็สามารถแบ่งน้ำหนักระหว่างกันได้ ดังนั้น เมื่อกู้คืนแบบจำลองจากเฉพาะน้ำหนัก ให้สร้างแบบจำลองที่มีสถาปัตยกรรมเดียวกันกับรุ่นดั้งเดิม แล้วตั้งค่าน้ำหนัก
ตอนนี้สร้างแบบจำลองใหม่ที่ยังไม่ผ่านการฝึกอบรมและประเมินในชุดทดสอบ โมเดลที่ไม่ได้รับการฝึกฝนจะแสดงที่ระดับโอกาส (ความแม่นยำประมาณ 10%):
# Create a basic model instance
model = create_model()
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 2.4473 - sparse_categorical_accuracy: 0.0980 - 145ms/epoch - 5ms/step Untrained model, accuracy: 9.80%
จากนั้นโหลดน้ำหนักจากจุดตรวจและประเมินใหม่:
# Loads the weights
model.load_weights(checkpoint_path)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4061 - sparse_categorical_accuracy: 0.8770 - 65ms/epoch - 2ms/step Restored model, accuracy: 87.70%
ตัวเลือกการโทรกลับจุดตรวจ
การโทรกลับมีตัวเลือกมากมายในการระบุชื่อที่ไม่ซ้ำสำหรับจุดตรวจและปรับความถี่ของจุดตรวจ
ฝึกโมเดลใหม่และบันทึกจุดตรวจที่มีชื่อไม่ซ้ำกันทุกๆ ห้ายุค:
# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
batch_size = 32
# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq=5*batch_size)
# Create a new model instance
model = create_model()
# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=50,
batch_size=batch_size,
callbacks=[cp_callback],
validation_data=(test_images, test_labels),
verbose=0)
Epoch 5: saving model to training_2/cp-0005.ckpt Epoch 10: saving model to training_2/cp-0010.ckpt Epoch 15: saving model to training_2/cp-0015.ckpt Epoch 20: saving model to training_2/cp-0020.ckpt Epoch 25: saving model to training_2/cp-0025.ckpt Epoch 30: saving model to training_2/cp-0030.ckpt Epoch 35: saving model to training_2/cp-0035.ckpt Epoch 40: saving model to training_2/cp-0040.ckpt Epoch 45: saving model to training_2/cp-0045.ckpt Epoch 50: saving model to training_2/cp-0050.ckpt <keras.callbacks.History at 0x7eff807703d0>
ตอนนี้ให้ดูที่จุดตรวจที่เกิดขึ้นและเลือกจุดล่าสุด:
os.listdir(checkpoint_dir)
['cp-0005.ckpt.data-00000-of-00001', 'cp-0050.ckpt.index', 'checkpoint', 'cp-0010.ckpt.index', 'cp-0035.ckpt.data-00000-of-00001', 'cp-0000.ckpt.data-00000-of-00001', 'cp-0050.ckpt.data-00000-of-00001', 'cp-0010.ckpt.data-00000-of-00001', 'cp-0020.ckpt.data-00000-of-00001', 'cp-0035.ckpt.index', 'cp-0040.ckpt.index', 'cp-0025.ckpt.data-00000-of-00001', 'cp-0045.ckpt.index', 'cp-0020.ckpt.index', 'cp-0025.ckpt.index', 'cp-0030.ckpt.data-00000-of-00001', 'cp-0030.ckpt.index', 'cp-0000.ckpt.index', 'cp-0045.ckpt.data-00000-of-00001', 'cp-0015.ckpt.index', 'cp-0015.ckpt.data-00000-of-00001', 'cp-0005.ckpt.index', 'cp-0040.ckpt.data-00000-of-00001']
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'
หากต้องการทดสอบ ให้รีเซ็ตโมเดลและโหลดจุดตรวจสอบล่าสุด:
# Create a new model instance
model = create_model()
# Load the previously saved weights
model.load_weights(latest)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 150ms/epoch - 5ms/step Restored model, accuracy: 87.70%ตัวยึดตำแหน่ง22
ไฟล์เหล่านี้คืออะไร?
รหัสด้านบนเก็บน้ำหนักไว้ในคอลเลกชันของไฟล์ที่จัดรูปแบบ จุดตรวจสอบ ที่มีเฉพาะน้ำหนักที่ฝึกแล้วในรูปแบบไบนารี จุดตรวจประกอบด้วย:
- ชาร์ดอย่างน้อยหนึ่งรายการที่มีตุ้มน้ำหนักของโมเดลของคุณ
- ไฟล์ดัชนีที่ระบุว่าน้ำหนักใดถูกเก็บไว้ในชาร์ดใด
หากคุณกำลังฝึกโมเดลในเครื่องเดียว คุณจะมีชาร์ดหนึ่งส่วนที่มีส่วนต่อท้าย: .data-00000-of-00001
ลดน้ำหนักด้วยตนเอง
การบันทึกน้ำหนักด้วยตนเองด้วยเมธอด Model.save_weights
โดยค่าเริ่มต้น tf.keras
— และโดยเฉพาะ save_weights
ใช้รูปแบบ จุดตรวจสอบ TensorFlow ที่มีนามสกุล .ckpt
(การบันทึกใน HDF5 ด้วยนามสกุล .h5
จะครอบคลุมอยู่ในคู่มือ บันทึกและกำหนดรูปแบบอนุกรม ):
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Create a new model instance
model = create_model()
# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 143ms/epoch - 4ms/step Restored model, accuracy: 87.70%
บันทึกโมเดลทั้งหมด
เรียก model.save
เพื่อบันทึกสถาปัตยกรรม น้ำหนัก และการกำหนดค่าการฝึกของโมเดลในไฟล์/โฟลเดอร์เดียว สิ่งนี้ทำให้คุณสามารถเอ็กซ์พอร์ตโมเดลเพื่อให้ใช้งานได้โดยไม่ต้องเข้าถึงโค้ด Python ดั้งเดิม* เนื่องจากสถานะเครื่องมือเพิ่มประสิทธิภาพถูกกู้คืน คุณจึงสามารถกลับมาฝึกต่อได้จากจุดที่ค้างไว้
โมเดลทั้งหมดสามารถบันทึกได้ในรูปแบบไฟล์ที่แตกต่างกันสองรูปแบบ ( SavedModel
และ HDF5
) รูปแบบ SavedModel
เป็นรูปแบบไฟล์เริ่มต้นใน TF2.x อย่างไรก็ตาม สามารถบันทึกโมเดลต่างๆ ในรูปแบบ HDF5
ได้ รายละเอียดเพิ่มเติมเกี่ยวกับการบันทึกโมเดลทั้งหมดในรูปแบบไฟล์ทั้งสองมีอธิบายไว้ด้านล่าง
การบันทึกโมเดลที่ใช้งานได้อย่างสมบูรณ์มีประโยชน์มาก คุณสามารถโหลดโมเดลเหล่านี้ใน TensorFlow.js ( Saved Model , HDF5 ) จากนั้นฝึกและเรียกใช้ในเว็บเบราว์เซอร์ หรือแปลงให้ทำงานบนอุปกรณ์มือถือโดยใช้ TensorFlow Lite ( Saved Model , HDF5 )
*อ็อบเจ็กต์ที่กำหนดเอง (เช่น โมเดลหรือเลเยอร์ย่อย) ต้องให้ความสนใจเป็นพิเศษเมื่อทำการบันทึกและโหลด ดูส่วนการ บันทึกวัตถุที่กำหนดเอง ด้านล่าง
รูปแบบโมเดลที่บันทึกไว้
รูปแบบ SavedModel เป็นอีกวิธีหนึ่งในการทำให้โมเดลเป็นอนุกรม โมเดลที่บันทึกในรูปแบบนี้สามารถกู้คืนได้โดยใช้ tf.keras.models.load_model
และเข้ากันได้กับ TensorFlow Serving คู่มือ SavedModel จะลงรายละเอียดเกี่ยวกับวิธีการให้บริการ/ตรวจสอบ SavedModel ส่วนด้านล่างแสดงขั้นตอนในการบันทึกและกู้คืนโมเดล
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
Epoch 1/5 32/32 [==============================] - 0s 2ms/step - loss: 1.1988 - sparse_categorical_accuracy: 0.6550 Epoch 2/5 32/32 [==============================] - 0s 2ms/step - loss: 0.4180 - sparse_categorical_accuracy: 0.8930 Epoch 3/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2900 - sparse_categorical_accuracy: 0.9220 Epoch 4/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2070 - sparse_categorical_accuracy: 0.9540 Epoch 5/5 32/32 [==============================] - 0s 2ms/step - loss: 0.1593 - sparse_categorical_accuracy: 0.9630 2022-01-26 07:30:22.888387: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function. WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function. WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate INFO:tensorflow:Assets written to: saved_model/my_model/assets
รูปแบบ SavedModel เป็นไดเร็กทอรีที่มีไบนารี protobuf และจุดตรวจสอบ TensorFlow ตรวจสอบไดเร็กทอรีโมเดลที่บันทึกไว้:
# my_model directory
ls saved_model
# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
my_model assets keras_metadata.pb saved_model.pb variables
โหลดโมเดล Keras ใหม่จากโมเดลที่บันทึกไว้:
new_model = tf.keras.models.load_model('saved_model/my_model')
# Check its architecture
new_model.summary()
Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_10 (Dense) (None, 512) 401920 dropout_5 (Dropout) (None, 512) 0 dense_11 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________
โมเดลที่กู้คืนถูกคอมไพล์ด้วยอาร์กิวเมนต์เดียวกันกับโมเดลดั้งเดิม ลองรันประเมินและคาดการณ์ด้วยโมเดลที่โหลด:
# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4577 - sparse_categorical_accuracy: 0.8430 - 156ms/epoch - 5ms/step Restored model, accuracy: 84.30% (1000, 10)ตัวยึดตำแหน่ง32
รูปแบบ HDF5
Keras จัดเตรียมรูปแบบการบันทึกพื้นฐานโดยใช้มาตรฐาน HDF5
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
Epoch 1/5 32/32 [==============================] - 0s 2ms/step - loss: 1.1383 - sparse_categorical_accuracy: 0.6970 Epoch 2/5 32/32 [==============================] - 0s 2ms/step - loss: 0.4094 - sparse_categorical_accuracy: 0.8920 Epoch 3/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2936 - sparse_categorical_accuracy: 0.9160 Epoch 4/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2050 - sparse_categorical_accuracy: 0.9460 Epoch 5/5 32/32 [==============================] - 0s 2ms/step - loss: 0.1485 - sparse_categorical_accuracy: 0.9690
ตอนนี้ สร้างโมเดลใหม่จากไฟล์นั้น:
# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')
# Show the model architecture
new_model.summary()
Model: "sequential_6" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_12 (Dense) (None, 512) 401920 dropout_6 (Dropout) (None, 512) 0 dense_13 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________ตัวยึดตำแหน่ง36
ตรวจสอบความถูกต้อง:
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
32/32 - 0s - loss: 0.4266 - sparse_categorical_accuracy: 0.8620 - 141ms/epoch - 4ms/step Restored model, accuracy: 86.20%
Keras บันทึกโมเดลด้วยการตรวจสอบสถาปัตยกรรมของพวกเขา เทคนิคนี้บันทึกทุกอย่าง:
- ค่าน้ำหนัก
- สถาปัตยกรรมของโมเดล
- การกำหนดค่าการฝึกของโมเดล (สิ่งที่คุณส่งผ่านไปยัง
.compile()
) - เครื่องมือเพิ่มประสิทธิภาพและสถานะของอุปกรณ์ (หากมี) (ซึ่งจะทำให้คุณสามารถเริ่มการฝึกใหม่จากจุดที่ค้างไว้ได้)
Keras ไม่สามารถบันทึกเครื่องมือเพิ่มประสิทธิภาพ v1.x
(จาก tf.compat.v1.train
) เนื่องจากไม่สามารถทำงานร่วมกับจุดตรวจได้ สำหรับตัวเพิ่มประสิทธิภาพ v1.x คุณต้องคอมไพล์โมเดลใหม่หลังจากโหลด—สูญเสียสถานะของตัวเพิ่มประสิทธิภาพ
กำลังบันทึกวัตถุที่กำหนดเอง
หากคุณกำลังใช้รูปแบบ SavedModel คุณสามารถข้ามส่วนนี้ได้ ความแตกต่างที่สำคัญระหว่าง HDF5 และ SavedModel คือ HDF5 ใช้การกำหนดค่าอ็อบเจ็กต์เพื่อบันทึกสถาปัตยกรรมโมเดล ในขณะที่ SavedModel จะบันทึกกราฟการดำเนินการ ดังนั้น SavedModels จึงสามารถบันทึกอ็อบเจ็กต์ที่กำหนดเองได้ เช่น โมเดลย่อยและเลเยอร์ที่กำหนดเองโดยไม่ต้องใช้โค้ดต้นฉบับ
ในการบันทึกออบเจ็กต์ที่กำหนดเองไปยัง HDF5 คุณต้องทำดังต่อไปนี้:
- กำหนดเมธอด
get_config
ในอ็อบเจ็กต์ของคุณ และเป็นทางเลือกfrom_config
classmethod-
get_config(self)
ส่งคืนพจนานุกรมพารามิเตอร์ JSON-serializable ที่จำเป็นในการสร้างวัตถุขึ้นใหม่ -
from_config(cls, config)
ใช้การกำหนดค่าที่ส่งคืนจากget_config
เพื่อสร้างวัตถุใหม่ โดยค่าเริ่มต้น ฟังก์ชันนี้จะใช้การกำหนดค่าเป็นค่าเริ่มต้น kwargs (return cls(**config)
)
-
- ส่งอ็อบเจ็กต์ไปยังอาร์กิวเมนต์
custom_objects
เมื่อโหลดโมเดล อาร์กิวเมนต์ต้องเป็นพจนานุกรมที่จับคู่ชื่อคลาสสตริงกับคลาส Python เช่นtf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})
ดูบทแนะนำ การเขียนเลเยอร์และโมเดลตั้งแต่เริ่มต้น สำหรับตัวอย่างของอ็อบเจกต์ที่กำหนดเองและ get_config
# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.