مشاهده در TensorFlow.org | در Google Colab اجرا شود | مشاهده منبع در GitHub | دانلود دفترچه یادداشت |
پیشرفت مدل را می توان در حین و بعد از آموزش ذخیره کرد. این بدان معناست که یک مدل میتواند از همان جایی که کارش را متوقف کرده است، از سر بگیرد و از زمانهای طولانی آموزش اجتناب کند. ذخیره همچنین به این معنی است که شما می توانید مدل خود را به اشتراک بگذارید و دیگران می توانند کار شما را بازسازی کنند. هنگام انتشار مدلها و تکنیکهای تحقیق، اکثر متخصصان یادگیری ماشینی به اشتراک میگذارند:
- کد برای ایجاد مدل، و
- وزنه ها یا پارامترهای آموزش دیده برای مدل
اشتراکگذاری این دادهها به دیگران کمک میکند تا بفهمند مدل چگونه کار میکند و خودشان آن را با دادههای جدید امتحان کنند.
گزینه ها
بسته به APIی که استفاده می کنید، روش های مختلفی برای ذخیره مدل های TensorFlow وجود دارد. این راهنما از tf.keras ، یک API سطح بالا برای ساخت و آموزش مدلها در TensorFlow استفاده میکند. برای سایر رویکردها به راهنمای ذخیره و بازیابی TensorFlow یا Saving in eager مراجعه کنید.
برپایی
نصب و واردات
نصب و وارد کردن TensorFlow و وابستگی ها:
pip install pyyaml h5py # Required to save models in HDF5 format
import os
import tensorflow as tf
from tensorflow import keras
print(tf.version.VERSION)
2.8.0-rc1
یک مجموعه داده نمونه دریافت کنید
برای نشان دادن نحوه ذخیره و بارگیری وزن ها، از مجموعه داده MNIST استفاده می کنید. برای سرعت بخشیدن به این اجراها، از 1000 مثال اول استفاده کنید:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
یک مدل تعریف کنید
با ساختن یک مدل متوالی ساده شروع کنید:
# Define a simple sequential model
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
# Create a basic model instance
model = create_model()
# Display the model's architecture
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 512) 401920 dropout (Dropout) (None, 512) 0 dense_1 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________
ایست های بازرسی را در طول آموزش ذخیره کنید
میتوانید بدون نیاز به آموزش مجدد از یک مدل آموزشدیده استفاده کنید، یا در صورت قطع شدن فرآیند آموزش، از جایی که آن را متوقف کردهاید، از آن استفاده کنید. پاسخ تماس tf.keras.callbacks.ModelCheckpoint
به شما این امکان را می دهد که به طور مداوم مدل را هم در حین و هم در پایان آموزش ذخیره کنید.
استفاده از ایست بازرسی
یک پاسخ تماس tf.keras.callbacks.ModelCheckpoint
ایجاد کنید که فقط در طول تمرین وزنه ها را ذخیره می کند:
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images, test_labels),
callbacks=[cp_callback]) # Pass callback to training
# This may generate warnings related to saving the state of the optimizer.
# These warnings (and similar warnings throughout this notebook)
# are in place to discourage outdated usage, and can be ignored.
Epoch 1/10 23/32 [====================>.........] - ETA: 0s - loss: 1.3666 - sparse_categorical_accuracy: 0.6060 Epoch 1: saving model to training_1/cp.ckpt 32/32 [==============================] - 1s 10ms/step - loss: 1.1735 - sparse_categorical_accuracy: 0.6690 - val_loss: 0.7180 - val_sparse_categorical_accuracy: 0.7750 Epoch 2/10 24/32 [=====================>........] - ETA: 0s - loss: 0.4238 - sparse_categorical_accuracy: 0.8789 Epoch 2: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.4201 - sparse_categorical_accuracy: 0.8810 - val_loss: 0.5621 - val_sparse_categorical_accuracy: 0.8150 Epoch 3/10 24/32 [=====================>........] - ETA: 0s - loss: 0.2795 - sparse_categorical_accuracy: 0.9336 Epoch 3: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.2815 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.4790 - val_sparse_categorical_accuracy: 0.8430 Epoch 4/10 24/32 [=====================>........] - ETA: 0s - loss: 0.2027 - sparse_categorical_accuracy: 0.9427 Epoch 4: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.2016 - sparse_categorical_accuracy: 0.9440 - val_loss: 0.4361 - val_sparse_categorical_accuracy: 0.8610 Epoch 5/10 24/32 [=====================>........] - ETA: 0s - loss: 0.1739 - sparse_categorical_accuracy: 0.9583 Epoch 5: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.1683 - sparse_categorical_accuracy: 0.9610 - val_loss: 0.4640 - val_sparse_categorical_accuracy: 0.8580 Epoch 6/10 23/32 [====================>.........] - ETA: 0s - loss: 0.1116 - sparse_categorical_accuracy: 0.9796 Epoch 6: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.1125 - sparse_categorical_accuracy: 0.9780 - val_loss: 0.4420 - val_sparse_categorical_accuracy: 0.8580 Epoch 7/10 24/32 [=====================>........] - ETA: 0s - loss: 0.0978 - sparse_categorical_accuracy: 0.9831 Epoch 7: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0989 - sparse_categorical_accuracy: 0.9820 - val_loss: 0.4163 - val_sparse_categorical_accuracy: 0.8590 Epoch 8/10 21/32 [==================>...........] - ETA: 0s - loss: 0.0669 - sparse_categorical_accuracy: 0.9911 Epoch 8: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 6ms/step - loss: 0.0690 - sparse_categorical_accuracy: 0.9910 - val_loss: 0.4411 - val_sparse_categorical_accuracy: 0.8600 Epoch 9/10 22/32 [===================>..........] - ETA: 0s - loss: 0.0495 - sparse_categorical_accuracy: 0.9972 Epoch 9: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0516 - sparse_categorical_accuracy: 0.9950 - val_loss: 0.4064 - val_sparse_categorical_accuracy: 0.8650 Epoch 10/10 24/32 [=====================>........] - ETA: 0s - loss: 0.0436 - sparse_categorical_accuracy: 0.9948 Epoch 10: saving model to training_1/cp.ckpt 32/32 [==============================] - 0s 5ms/step - loss: 0.0437 - sparse_categorical_accuracy: 0.9960 - val_loss: 0.4061 - val_sparse_categorical_accuracy: 0.8770 <keras.callbacks.History at 0x7eff8d865390>
این یک مجموعه واحد از فایلهای پست بازرسی TensorFlow ایجاد میکند که در پایان هر دوره بهروزرسانی میشوند:
os.listdir(checkpoint_dir)
['checkpoint', 'cp.ckpt.index', 'cp.ckpt.data-00000-of-00001']
تا زمانی که دو مدل معماری یکسانی دارند، می توانید وزن ها را بین آنها به اشتراک بگذارید. بنابراین، هنگام بازیابی یک مدل از وزن فقط، یک مدل با معماری مشابه مدل اصلی ایجاد کنید و سپس وزن آن را تنظیم کنید.
اکنون یک مدل تازه و آموزش ندیده را بازسازی کنید و آن را در مجموعه آزمایشی ارزیابی کنید. یک مدل آموزش ندیده در سطوح شانسی (~10% دقت) عمل می کند:
# Create a basic model instance
model = create_model()
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 2.4473 - sparse_categorical_accuracy: 0.0980 - 145ms/epoch - 5ms/step Untrained model, accuracy: 9.80%
سپس وزنه ها را از ایست بازرسی بارگیری کنید و دوباره ارزیابی کنید:
# Loads the weights
model.load_weights(checkpoint_path)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4061 - sparse_categorical_accuracy: 0.8770 - 65ms/epoch - 2ms/step Restored model, accuracy: 87.70%
گزینه های برگشت به تماس ایست بازرسی
پاسخ به تماس چندین گزینه برای ارائه نامهای منحصر به فرد برای پستهای بازرسی و تنظیم فرکانس بازرسی فراهم میکند.
یک مدل جدید آموزش دهید و هر پنج دوره یک بار پست های بازرسی با نام منحصر به فرد را ذخیره کنید:
# Include the epoch in the file name (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
batch_size = 32
# Create a callback that saves the model's weights every 5 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq=5*batch_size)
# Create a new model instance
model = create_model()
# Save the weights using the `checkpoint_path` format
model.save_weights(checkpoint_path.format(epoch=0))
# Train the model with the new callback
model.fit(train_images,
train_labels,
epochs=50,
batch_size=batch_size,
callbacks=[cp_callback],
validation_data=(test_images, test_labels),
verbose=0)
Epoch 5: saving model to training_2/cp-0005.ckpt Epoch 10: saving model to training_2/cp-0010.ckpt Epoch 15: saving model to training_2/cp-0015.ckpt Epoch 20: saving model to training_2/cp-0020.ckpt Epoch 25: saving model to training_2/cp-0025.ckpt Epoch 30: saving model to training_2/cp-0030.ckpt Epoch 35: saving model to training_2/cp-0035.ckpt Epoch 40: saving model to training_2/cp-0040.ckpt Epoch 45: saving model to training_2/cp-0045.ckpt Epoch 50: saving model to training_2/cp-0050.ckpt <keras.callbacks.History at 0x7eff807703d0>
اکنون، به نقاط بازرسی حاصل نگاه کنید و آخرین مورد را انتخاب کنید:
os.listdir(checkpoint_dir)
['cp-0005.ckpt.data-00000-of-00001', 'cp-0050.ckpt.index', 'checkpoint', 'cp-0010.ckpt.index', 'cp-0035.ckpt.data-00000-of-00001', 'cp-0000.ckpt.data-00000-of-00001', 'cp-0050.ckpt.data-00000-of-00001', 'cp-0010.ckpt.data-00000-of-00001', 'cp-0020.ckpt.data-00000-of-00001', 'cp-0035.ckpt.index', 'cp-0040.ckpt.index', 'cp-0025.ckpt.data-00000-of-00001', 'cp-0045.ckpt.index', 'cp-0020.ckpt.index', 'cp-0025.ckpt.index', 'cp-0030.ckpt.data-00000-of-00001', 'cp-0030.ckpt.index', 'cp-0000.ckpt.index', 'cp-0045.ckpt.data-00000-of-00001', 'cp-0015.ckpt.index', 'cp-0015.ckpt.data-00000-of-00001', 'cp-0005.ckpt.index', 'cp-0040.ckpt.data-00000-of-00001']
latest = tf.train.latest_checkpoint(checkpoint_dir)
latest
'training_2/cp-0050.ckpt'
برای آزمایش، مدل را بازنشانی کنید و آخرین نقطه بازرسی را بارگیری کنید:
# Create a new model instance
model = create_model()
# Load the previously saved weights
model.load_weights(latest)
# Re-evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 150ms/epoch - 5ms/step Restored model, accuracy: 87.70%
این فایل ها چیست؟
کد بالا وزنها را در مجموعهای از فایلهای با فرمت نقطه بازرسی ذخیره میکند که فقط وزنهای آموزشدیدهشده را در قالب باینری دارند. پست های بازرسی شامل:
- یک یا چند قطعه که حاوی وزنه های مدل شما هستند.
- یک فایل فهرست که نشان می دهد وزن ها در کدام قطعه ذخیره می شوند.
اگر در حال آموزش یک مدل بر روی یک ماشین هستید، یک قطعه با پسوند .data-00000-of-00001
خواهید داشت.
وزنه ها را به صورت دستی ذخیره کنید
ذخیره وزن به صورت دستی با روش Model.save_weights
. بهطور پیشفرض، tf.keras
– و save_weights
– از فرمت نقطه بازرسی .ckpt
با پسوند ckpt. استفاده میکند (ذخیره در HDF5 با پسوند .h5
در راهنمای Save and serialize models پوشش داده شده است):
# Save the weights
model.save_weights('./checkpoints/my_checkpoint')
# Create a new model instance
model = create_model()
# Restore the weights
model.load_weights('./checkpoints/my_checkpoint')
# Evaluate the model
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
32/32 - 0s - loss: 0.4996 - sparse_categorical_accuracy: 0.8770 - 143ms/epoch - 4ms/step Restored model, accuracy: 87.70%
کل مدل را ذخیره کنید
برای ذخیره معماری، وزن و پیکربندی آموزشی مدل در یک فایل/پوشه واحد، model.save
را فراخوانی کنید. این به شما امکان می دهد یک مدل را صادر کنید تا بدون دسترسی به کد اصلی پایتون* از آن استفاده کنید. از آنجایی که حالت بهینه ساز بازیابی شده است، می توانید تمرین را دقیقا از همان جایی که متوقف کرده اید از سر بگیرید.
کل مدل را می توان در دو فرمت فایل مختلف ( SavedModel
و HDF5
) ذخیره کرد. فرمت SavedModel
فرمت فایل پیش فرض در TF2.x است. با این حال، مدل ها را می توان در فرمت HDF5
ذخیره کرد. جزئیات بیشتر در مورد ذخیره کل مدل ها در دو فرمت فایل در زیر توضیح داده شده است.
ذخیره یک مدل کاملاً کاربردی بسیار مفید است—شما می توانید آنها را در TensorFlow.js بارگذاری کنید ( مدل ذخیره شده ، HDF5 ) و سپس آنها را آموزش دهید و در مرورگرهای وب اجرا کنید، یا با استفاده از TensorFlow Lite آنها را برای اجرا در دستگاه های تلفن همراه تبدیل کنید ( مدل ذخیره شده ، HDF5 )
*اشیاء سفارشی (مثلاً مدلها یا لایههای طبقهبندیشده) به توجه ویژه در هنگام ذخیره و بارگذاری نیاز دارند. بخش ذخیره اشیاء سفارشی را در زیر ببینید
فرمت SavedModel
فرمت SavedModel راه دیگری برای سریال سازی مدل ها است. مدل های ذخیره شده در این قالب را می توان با استفاده از tf.keras.models.load_model
بازیابی کرد و با سرویس TensorFlow سازگار است. راهنمای SavedModel به جزئیات نحوه سرویس/بازرسی SavedModel میپردازد. بخش زیر مراحل ذخیره و بازیابی مدل را نشان می دهد.
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model as a SavedModel.
!mkdir -p saved_model
model.save('saved_model/my_model')
Epoch 1/5 32/32 [==============================] - 0s 2ms/step - loss: 1.1988 - sparse_categorical_accuracy: 0.6550 Epoch 2/5 32/32 [==============================] - 0s 2ms/step - loss: 0.4180 - sparse_categorical_accuracy: 0.8930 Epoch 3/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2900 - sparse_categorical_accuracy: 0.9220 Epoch 4/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2070 - sparse_categorical_accuracy: 0.9540 Epoch 5/5 32/32 [==============================] - 0s 2ms/step - loss: 0.1593 - sparse_categorical_accuracy: 0.9630 2022-01-26 07:30:22.888387: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function. WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate WARNING:tensorflow:Detecting that an object or model or tf.train.Checkpoint is being deleted with unrestored values. See the following logs for the specific values in question. To silence these warnings, use `status.expect_partial()`. See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#restorefor details about the status object returned by the restore function. WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.iter WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_1 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.beta_2 WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.decay WARNING:tensorflow:Value in checkpoint could not be found in the restored object: (root).optimizer.learning_rate INFO:tensorflow:Assets written to: saved_model/my_model/assets
فرمت SavedModel یک دایرکتوری حاوی یک پروتوباف باینری و یک چک پوینت TensorFlow است. دایرکتوری مدل ذخیره شده را بررسی کنید:
# my_model directory
ls saved_model
# Contains an assets folder, saved_model.pb, and variables folder.
ls saved_model/my_model
my_model assets keras_metadata.pb saved_model.pb variables
یک مدل Keras تازه از مدل ذخیره شده را بارگیری مجدد کنید:
new_model = tf.keras.models.load_model('saved_model/my_model')
# Check its architecture
new_model.summary()
Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_10 (Dense) (None, 512) 401920 dropout_5 (Dropout) (None, 512) 0 dense_11 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________
مدل بازیابی شده با همان آرگومان های مدل اصلی کامپایل می شود. اجرای ارزیابی و پیش بینی با مدل بارگذاری شده را امتحان کنید:
# Evaluate the restored model
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
print(new_model.predict(test_images).shape)
32/32 - 0s - loss: 0.4577 - sparse_categorical_accuracy: 0.8430 - 156ms/epoch - 5ms/step Restored model, accuracy: 84.30% (1000, 10)
فرمت HDF5
Keras یک فرمت ذخیره اولیه را با استفاده از استاندارد HDF5 ارائه می دهد.
# Create and train a new model instance.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# Save the entire model to a HDF5 file.
# The '.h5' extension indicates that the model should be saved to HDF5.
model.save('my_model.h5')
Epoch 1/5 32/32 [==============================] - 0s 2ms/step - loss: 1.1383 - sparse_categorical_accuracy: 0.6970 Epoch 2/5 32/32 [==============================] - 0s 2ms/step - loss: 0.4094 - sparse_categorical_accuracy: 0.8920 Epoch 3/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2936 - sparse_categorical_accuracy: 0.9160 Epoch 4/5 32/32 [==============================] - 0s 2ms/step - loss: 0.2050 - sparse_categorical_accuracy: 0.9460 Epoch 5/5 32/32 [==============================] - 0s 2ms/step - loss: 0.1485 - sparse_categorical_accuracy: 0.9690
حالا مدل را از آن فایل دوباره بسازید:
# Recreate the exact same model, including its weights and the optimizer
new_model = tf.keras.models.load_model('my_model.h5')
# Show the model architecture
new_model.summary()
Model: "sequential_6" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense_12 (Dense) (None, 512) 401920 dropout_6 (Dropout) (None, 512) 0 dense_13 (Dense) (None, 10) 5130 ================================================================= Total params: 407,050 Trainable params: 407,050 Non-trainable params: 0 _________________________________________________________________
دقت آن را بررسی کنید:
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
32/32 - 0s - loss: 0.4266 - sparse_categorical_accuracy: 0.8620 - 141ms/epoch - 4ms/step Restored model, accuracy: 86.20%
Keras مدل ها را با بررسی معماری آنها ذخیره می کند. این تکنیک همه چیز را نجات می دهد:
- ارزش های وزنی
- معماری مدل
- پیکربندی آموزشی مدل (آنچه به روش
.compile()
منتقل می کنید) - بهینه ساز و وضعیت آن، در صورت وجود (این به شما امکان می دهد آموزش را از جایی که متوقف کردید دوباره شروع کنید)
Keras قادر به ذخیره بهینه سازهای v1.x
(از tf.compat.v1.train
) نیست زیرا با نقاط بازرسی سازگار نیستند. برای بهینهسازهای v1.x، باید پس از بارگذاری، مدل را دوباره کامپایل کنید—از دست دادن حالت بهینهساز.
ذخیره اشیاء سفارشی
اگر از فرمت SavedModel استفاده می کنید، می توانید از این بخش صرف نظر کنید. تفاوت اصلی بین HDF5 و SavedModel این است که HDF5 از تنظیمات شی برای ذخیره معماری مدل استفاده می کند، در حالی که SavedModel نمودار اجرا را ذخیره می کند. بنابراین، SavedModels میتواند اشیاء سفارشی مانند مدلهای زیر کلاس و لایههای سفارشی را بدون نیاز به کد اصلی ذخیره کند.
برای ذخیره اشیاء سفارشی در HDF5، باید موارد زیر را انجام دهید:
- یک متد
get_config
را در شیء خود تعریف کنید، و به صورت اختیاری یکfrom_config
config.-
get_config(self)
یک فرهنگ لغت قابل سریالسازی با JSON از پارامترهای مورد نیاز برای ایجاد مجدد شی را برمیگرداند. -
from_config(cls, config)
از پیکربندی برگشتی ازget_config
برای ایجاد یک شی جدید استفاده می کند. به طور پیش فرض، این تابع از پیکربندی به عنوان کوارگ های اولیه استفاده می کند (return cls(**config)
).
-
- هنگام بارگذاری مدل، شی را به آرگومان
custom_objects
کنید. آرگومان باید یک فرهنگ لغت باشد که نام کلاس رشته را به کلاس پایتون نگاشت می کند. به عنوان مثالtf.keras.models.load_model(path, custom_objects={'CustomLayer': CustomLayer})
برای نمونه هایی از اشیاء سفارشی و get_config
به آموزش نوشتن لایه ها و مدل ها از ابتدا مراجعه کنید.
# MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.