ذخیره و بارگذاری مدل ها

مشاهده در TensorFlow.org در Google Colab اجرا شود مشاهده منبع در GitHub دانلود دفترچه یادداشت

پیشرفت مدل را می توان در حین و بعد از آموزش ذخیره کرد. این بدان معناست که یک مدل می‌تواند از همان جایی که کارش را متوقف کرده است، از سر بگیرد و از زمان‌های طولانی آموزش اجتناب کند. ذخیره همچنین به این معنی است که شما می توانید مدل خود را به اشتراک بگذارید و دیگران می توانند کار شما را بازسازی کنند. هنگام انتشار مدل‌ها و تکنیک‌های تحقیق، اکثر متخصصان یادگیری ماشینی به اشتراک می‌گذارند:

  • کد برای ایجاد مدل، و
  • وزنه ها یا پارامترهای آموزش دیده برای مدل

اشتراک‌گذاری این داده‌ها به دیگران کمک می‌کند تا بفهمند مدل چگونه کار می‌کند و خودشان آن را با داده‌های جدید امتحان کنند.

گزینه ها

بسته به APIی که استفاده می کنید، روش های مختلفی برای ذخیره مدل های TensorFlow وجود دارد. این راهنما از tf.keras ، یک API سطح بالا برای ساخت و آموزش مدل‌ها در TensorFlow استفاده می‌کند. برای سایر رویکردها به راهنمای ذخیره و بازیابی TensorFlow یا Saving in eager مراجعه کنید.

برپایی

نصب و واردات

نصب و وارد کردن TensorFlow و وابستگی ها:

pip install pyyaml h5py  # Required to save models in HDF5 format
import os

import tensorflow as tf
from tensorflow import keras

print(tf.version.VERSION)
2.8.0-rc1

یک مجموعه داده نمونه دریافت کنید

برای نشان دادن نحوه ذخیره و بارگیری وزن ها، از مجموعه داده MNIST استفاده می کنید. برای سرعت بخشیدن به این اجراها، از 1000 مثال اول استفاده کنید:

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

یک مدل تعریف کنید

با ساختن یک مدل متوالی ساده شروع کنید:

# Define a simple sequential model
def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])

  return model

# Create a basic model instance
model = create_model()

# Display the model's architecture
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense (Dense)               (None, 512)               401920    
                                                                 
 dropout (Dropout)           (None, 512)               0         
                                                                 
 dense_1 (Dense)             (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

ایست های بازرسی را در طول آموزش ذخیره کنید

می‌توانید بدون نیاز به آموزش مجدد از یک مدل آموزش‌دیده استفاده کنید، یا در صورت قطع شدن فرآیند آموزش، از جایی که آن را متوقف کرده‌اید، از آن استفاده کنید. پاسخ تماس tf.keras.callbacks.ModelCheckpoint به شما این امکان را می دهد که به طور مداوم مدل را هم در حین و هم در پایان آموزش ذخیره کنید.

استفاده از ایست بازرسی

یک پاسخ تماس tf.keras.callbacks.ModelCheckpoint ایجاد کنید که فقط در طول تمرین وزنه ها را ذخیره می کند:

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

# Train the model with the new callback
model.fit(train_images, 
          train_labels,  
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])  # Pass callback to training

# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10
23/32 [====================>.........] - ETA: 0s - loss: 1.3666 - sparse_categorical_accuracy: 0.6060 
Epoch 1: saving model to training_1/cp.ckpt
32/32 [==============================] - 1s 10ms/step - loss: 1.1735 - sparse_categorical_accuracy: 0.6690 - val_loss: 0.7180 - val_sparse_categorical_accuracy: 0.7750
Epoch 2/10
24/32 [=====================>........] - ETA: 0s - loss: 0.4238 - sparse_categorical_accuracy: 0.8789
Epoch 2: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.4201 - sparse_categorical_accuracy: 0.8810 - val_loss: 0.5621 - val_sparse_categorical_accuracy: 0.8150
Epoch 3/10
24/32 [=====================>........] - ETA: 0s - loss: 0.2795 - sparse_categorical_accuracy: 0.9336
Epoch 3: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.2815 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.4790 - val_sparse_categorical_accuracy: 0.8430
Epoch 4/10
24/32 [=====================>........] - ETA: 0s - loss: 0.2027 - sparse_categorical_accuracy: 0.9427
Epoch 4: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.2016 - sparse_categorical_accuracy: 0.9440 - val_loss: 0.4361 - val_sparse_categorical_accuracy: 0.8610
Epoch 5/10
24/32 [=====================>........] - ETA: 0s - loss: 0.1739 - sparse_categorical_accuracy: 0.9583
Epoch 5: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.1683 - sparse_categorical_accuracy: 0.9610 - val_loss: 0.4640 - val_sparse_categorical_accuracy: 0.8580
Epoch 6/10
23/32 [====================>.........] - ETA: 0s - loss: 0.1116 - sparse_categorical_accuracy: 0.9796
Epoch 6: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.1125 - sparse_categorical_accuracy: 0.9780 - val_loss: 0.4420 - val_sparse_categorical_accuracy: 0.8580
Epoch 7/10
24/32 [=====================>........] - ETA: 0s - loss: 0.0978 - sparse_categorical_accuracy: 0.9831
Epoch 7: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0989 - sparse_categorical_accuracy: 0.9820 - val_loss: 0.4163 - val_sparse_categorical_accuracy: 0.8590
Epoch 8/10
21/32 [==================>...........] - ETA: 0s - loss: 0.0669 - sparse_categorical_accuracy: 0.9911
Epoch 8: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 6ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.4411 - val_sparse_categorical_accuracy: 0.8600
Epoch 9/10
22/32 [===================>..........] - ETA: 0s - loss: 0.0495 - sparse_categorical_accuracy: 0.9972
Epoch 9: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0516 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.4064 - val_sparse_categorical_accuracy: 0.8650
Epoch 10/10
24/32 [=====================>........] - ETA: 0s - loss: 0.0436 - sparse_categorical_accuracy: 0.9948
Epoch 10: saving model to training_1/cp.ckpt
32/32 [==============================] - 0s 5ms/step - loss: 0.0437 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.4061 - val_sparse_categorical_accuracy: 0.8770
<keras.callbacks.History at 0x7eff8d865390>

این یک مجموعه واحد از فایل‌های پست بازرسی TensorFlow ایجاد می‌کند که در پایان هر دوره به‌روزرسانی می‌شوند:

os.listdir(checkpoint_dir)
['checkpoint', 'cp.ckpt.index', 'cp.ckpt.data-00000-of-00001']

تا زمانی که دو مدل معماری یکسانی دارند، می توانید وزن ها را بین آنها به اشتراک بگذارید. بنابراین، هنگام بازیابی یک مدل از وزن فقط، یک مدل با معماری مشابه مدل اصلی ایجاد کنید و سپس وزن آن را تنظیم کنید.

اکنون یک مدل تازه و آموزش ندیده را بازسازی کنید و آن را در مجموعه آزمایشی ارزیابی کنید. یک مدل آموزش ندیده در سطوح شانسی (~10% دقت) عمل می کند:

# Create a basic model instance
model = create_model()

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 2.4473 - sparse_categorical_accuracy: 0.0980 - 145ms/epoch - 5ms/step
Untrained model, accuracy:  9.80%

سپس وزنه ها را از ایست بازرسی بارگیری کنید و دوباره ارزیابی کنید:

# Loads the weights
model.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4061 - sparse_categorical_accuracy: 0.8770 - 65ms/epoch - 2ms/step
Restored model, accuracy: 87.70%

گزینه های برگشت به تماس ایست بازرسی

پاسخ به تماس چندین گزینه برای ارائه نام‌های منحصر به فرد برای پست‌های بازرسی و تنظیم فرکانس بازرسی فراهم می‌کند.

یک مدل جدید آموزش دهید و هر پنج دوره یک بار پست های بازرسی با نام منحصر به فرد را ذخیره کنید:

# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

batch_size = 32

# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, 
    verbose=1, 
    save_weights_only=True,
    save_freq=5*batch_size)

# Create a new model instance
model = create_model()

# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          epochs=50, 
          batch_size=batch_size, 
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)
Epoch 5: saving model to training_2/cp-0005.ckpt

Epoch 10: saving model to training_2/cp-0010.ckpt

Epoch 15: saving model to training_2/cp-0015.ckpt

Epoch 20: saving model to training_2/cp-0020.ckpt

Epoch 25: saving model to training_2/cp-0025.ckpt

Epoch 30: saving model to training_2/cp-0030.ckpt

Epoch 35: saving model to training_2/cp-0035.ckpt

Epoch 40: saving model to training_2/cp-0040.ckpt

Epoch 45: saving model to training_2/cp-0045.ckpt

Epoch 50: saving model to training_2/cp-0050.ckpt
<keras.callbacks.History at 0x7eff807703d0>

اکنون، به نقاط بازرسی حاصل نگاه کنید و آخرین مورد را انتخاب کنید:

os.listdir(checkpoint_dir)
['cp-0005.ckpt.data-00000-of-00001',
 'cp-0050.ckpt.index',
 'checkpoint',
 'cp-0010.ckpt.index',
 'cp-0035.ckpt.data-00000-of-00001',
 'cp-0000.ckpt.data-00000-of-00001',
 'cp-0050.ckpt.data-00000-of-00001',
 'cp-0010.ckpt.data-00000-of-00001',
 'cp-0020.ckpt.data-00000-of-00001',
 'cp-0035.ckpt.index',
 'cp-0040.ckpt.index',
 'cp-0025.ckpt.data-00000-of-00001',
 'cp-0045.ckpt.index',
 'cp-0020.ckpt.index',
 'cp-0025.ckpt.index',
 'cp-0030.ckpt.data-00000-of-00001',
 'cp-0030.ckpt.index',
 'cp-0000.ckpt.index',
 'cp-0045.ckpt.data-00000-of-00001',
 'cp-0015.ckpt.index',
 'cp-0015.ckpt.data-00000-of-00001',
 'cp-0005.ckpt.index',
 'cp-0040.ckpt.data-00000-of-00001']
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'

برای آزمایش، مدل را بازنشانی کنید و آخرین نقطه بازرسی را بارگیری کنید:

# Create a new model instance
model = create_model()

# Load the previously saved weights
model.load_weights(latest)

# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 150ms/epoch - 5ms/step
Restored model, accuracy: 87.70%

این فایل ها چیست؟

کد بالا وزن‌ها را در مجموعه‌ای از فایل‌های با فرمت نقطه بازرسی ذخیره می‌کند که فقط وزن‌های آموزش‌دیده‌شده را در قالب باینری دارند. پست های بازرسی شامل:

  • یک یا چند قطعه که حاوی وزنه های مدل شما هستند.
  • یک فایل فهرست که نشان می دهد وزن ها در کدام قطعه ذخیره می شوند.

اگر در حال آموزش یک مدل بر روی یک ماشین هستید، یک قطعه با پسوند .data-00000-of-00001 خواهید داشت.

وزنه ها را به صورت دستی ذخیره کنید

ذخیره وزن به صورت دستی با روش Model.save_weights . به‌طور پیش‌فرض، tf.keras – و save_weights – از فرمت نقطه بازرسی .ckpt با پسوند ckpt. استفاده می‌کند (ذخیره در HDF5 با پسوند .h5 در راهنمای Save and serialize models پوشش داده شده است):

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Create a new model instance
model = create_model()

# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')

# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 143ms/epoch - 4ms/step
Restored model, accuracy: 87.70%

کل مدل را ذخیره کنید

برای ذخیره معماری، وزن و پیکربندی آموزشی مدل در یک فایل/پوشه واحد، model.save را فراخوانی کنید. این به شما امکان می دهد یک مدل را صادر کنید تا بدون دسترسی به کد اصلی پایتون* از آن استفاده کنید. از آنجایی که حالت بهینه ساز بازیابی شده است، می توانید تمرین را دقیقا از همان جایی که متوقف کرده اید از سر بگیرید.

کل مدل را می توان در دو فرمت فایل مختلف ( SavedModel و HDF5 ) ذخیره کرد. فرمت SavedModel فرمت فایل پیش فرض در TF2.x است. با این حال، مدل ها را می توان در فرمت HDF5 ذخیره کرد. جزئیات بیشتر در مورد ذخیره کل مدل ها در دو فرمت فایل در زیر توضیح داده شده است.

ذخیره یک مدل کاملاً کاربردی بسیار مفید است—شما می توانید آنها را در TensorFlow.js بارگذاری کنید ( مدل ذخیره شده ، HDF5 ) و سپس آنها را آموزش دهید و در مرورگرهای وب اجرا کنید، یا با استفاده از TensorFlow Lite آنها را برای اجرا در دستگاه های تلفن همراه تبدیل کنید ( مدل ذخیره شده ، HDF5 )

*اشیاء سفارشی (مثلاً مدل‌ها یا لایه‌های طبقه‌بندی‌شده) به توجه ویژه در هنگام ذخیره و بارگذاری نیاز دارند. بخش ذخیره اشیاء سفارشی را در زیر ببینید

فرمت SavedModel

فرمت SavedModel راه دیگری برای سریال سازی مدل ها است. مدل های ذخیره شده در این قالب را می توان با استفاده از tf.keras.models.load_model بازیابی کرد و با سرویس TensorFlow سازگار است. راهنمای SavedModel به جزئیات نحوه سرویس/بازرسی SavedModel می‌پردازد. بخش زیر مراحل ذخیره و بازیابی مدل را نشان می دهد.

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1988 - sparse_categorical_accuracy: 0.6550
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4180 - sparse_categorical_accuracy: 0.8930
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2900 - sparse_categorical_accuracy: 0.9220
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2070 - sparse_categorical_accuracy: 0.9540
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1593 - sparse_categorical_accuracy: 0.9630
2022-01-26 07:30:22.888387: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate
WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function.
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay
WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate
INFO:tensorflow:Assets written to: saved_model/my_model/assets

فرمت SavedModel یک دایرکتوری حاوی یک پروتوباف باینری و یک چک پوینت TensorFlow است. دایرکتوری مدل ذخیره شده را بررسی کنید:

# my_model directory
ls saved_model

# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
my_model
assets  keras_metadata.pb  saved_model.pb  variables

یک مدل Keras تازه از مدل ذخیره شده را بارگیری مجدد کنید:

new_model = tf.keras.models.load_model('saved_model/my_model')

# Check its architecture
new_model.summary()
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_10 (Dense)            (None, 512)               401920    
                                                                 
 dropout_5 (Dropout)         (None, 512)               0         
                                                                 
 dense_11 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

مدل بازیابی شده با همان آرگومان های مدل اصلی کامپایل می شود. اجرای ارزیابی و پیش بینی با مدل بارگذاری شده را امتحان کنید:

# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4577 - sparse_categorical_accuracy: 0.8430 - 156ms/epoch - 5ms/step
Restored model, accuracy: 84.30%
(1000, 10)

فرمت HDF5

Keras یک فرمت ذخیره اولیه را با استفاده از استاندارد HDF5 ارائه می دهد.

# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)

# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
Epoch 1/5
32/32 [==============================] - 0s 2ms/step - loss: 1.1383 - sparse_categorical_accuracy: 0.6970
Epoch 2/5
32/32 [==============================] - 0s 2ms/step - loss: 0.4094 - sparse_categorical_accuracy: 0.8920
Epoch 3/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2936 - sparse_categorical_accuracy: 0.9160
Epoch 4/5
32/32 [==============================] - 0s 2ms/step - loss: 0.2050 - sparse_categorical_accuracy: 0.9460
Epoch 5/5
32/32 [==============================] - 0s 2ms/step - loss: 0.1485 - sparse_categorical_accuracy: 0.9690

حالا مدل را از آن فایل دوباره بسازید:

# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')

# Show the model architecture
new_model.summary()
Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_12 (Dense)            (None, 512)               401920    
                                                                 
 dropout_6 (Dropout)         (None, 512)               0         
                                                                 
 dense_13 (Dense)            (None, 10)                5130      
                                                                 
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________

دقت آن را بررسی کنید:

loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
32/32 - 0s - loss: 0.4266 - sparse_categorical_accuracy: 0.8620 - 141ms/epoch - 4ms/step
Restored model, accuracy: 86.20%

Keras مدل ها را با بررسی معماری آنها ذخیره می کند. این تکنیک همه چیز را نجات می دهد:

  • ارزش های وزنی
  • معماری مدل
  • پیکربندی آموزشی مدل (آنچه به روش .compile() منتقل می کنید)
  • بهینه ساز و وضعیت آن، در صورت وجود (این به شما امکان می دهد آموزش را از جایی که متوقف کردید دوباره شروع کنید)

Keras قادر به ذخیره بهینه سازهای v1.x (از tf.compat.v1.train ) نیست زیرا با نقاط بازرسی سازگار نیستند. برای بهینه‌سازهای v1.x، باید پس از بارگذاری، مدل را دوباره کامپایل کنید—از دست دادن حالت بهینه‌ساز.

ذخیره اشیاء سفارشی

اگر از فرمت SavedModel استفاده می کنید، می توانید از این بخش صرف نظر کنید. تفاوت اصلی بین HDF5 و SavedModel این است که HDF5 از تنظیمات شی برای ذخیره معماری مدل استفاده می کند، در حالی که SavedModel نمودار اجرا را ذخیره می کند. بنابراین، SavedModels می‌تواند اشیاء سفارشی مانند مدل‌های زیر کلاس و لایه‌های سفارشی را بدون نیاز به کد اصلی ذخیره کند.

برای ذخیره اشیاء سفارشی در HDF5، باید موارد زیر را انجام دهید:

  1. یک متد get_config را در شیء خود تعریف کنید، و به صورت اختیاری یک from_config config.
    • get_config(self) یک فرهنگ لغت قابل سریال‌سازی با JSON از پارامترهای مورد نیاز برای ایجاد مجدد شی را برمی‌گرداند.
    • from_config(cls, config) از پیکربندی برگشتی از get_config برای ایجاد یک شی جدید استفاده می کند. به طور پیش فرض، این تابع از پیکربندی به عنوان کوارگ های اولیه استفاده می کند ( return cls(**config) ).
  2. هنگام بارگذاری مدل، شی را به آرگومان custom_objects کنید. آرگومان باید یک فرهنگ لغت باشد که نام کلاس رشته را به کلاس پایتون نگاشت می کند. به عنوان مثال tf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})

برای نمونه هایی از اشیاء سفارشی و get_config به آموزش نوشتن لایه ها و مدل ها از ابتدا مراجعه کنید.

# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.