הצג באתר TensorFlow.org | הפעל בגוגל קולאב | צפה במקור ב-GitHub | הורד מחברת |
סקירה כללית
מקובל לשמור ולטעון דגם במהלך האימון. ישנן שתי קבוצות של ממשקי API לשמירה וטעינה של מודל Keras: API ברמה גבוהה וממשק API ברמה נמוכה. מדריך זה מדגים כיצד ניתן להשתמש בממשקי ה-API של SavedModel בעת שימוש ב- tf.distribute.Strategy
. כדי ללמוד על SavedModel והסדרה באופן כללי, אנא קרא את מדריך המודלים השמורים ואת המדריך להסדרת המודלים של Keras . נתחיל בדוגמה פשוטה:
תלות בייבוא:
import tensorflow_datasets as tfds
import tensorflow as tf
הכן את הנתונים והמודל באמצעות tf.distribute.Strategy
:
mirrored_strategy = tf.distribute.MirroredStrategy()
def get_data():
datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
return train_dataset, eval_dataset
def get_model():
with mirrored_strategy.scope():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
אימון הדגם:
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). Epoch 1/2 2022-01-26 05:41:11.916000: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 938/938 [==============================] - 11s 5ms/step - loss: 0.1873 - sparse_categorical_accuracy: 0.9451 Epoch 2/2 938/938 [==============================] - 3s 3ms/step - loss: 0.0641 - sparse_categorical_accuracy: 0.9807 <keras.callbacks.History at 0x7f3b900396d0>
שמור וטען את הדגם
עכשיו, כשיש לך מודל פשוט לעבוד איתו, בואו נסתכל על ממשקי ה-API של שמירה/טעינה. קיימות שתי קבוצות של ממשקי API זמינים:
-
model.save
model.save ברמה גבוהה ו-tf.keras.models.load_model
- tf.saved_model.save ו-
tf.saved_model.save
ברמהtf.saved_model.load
ממשקי API של Keras
הנה דוגמה לשמירה וטעינה של דגם עם ממשקי API של Keras:
keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
2022-01-26 05:41:26.593570: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmp/keras_save/assets INFO:tensorflow:Assets written to: /tmp/keras_save/assets
שחזר את הדגם ללא tf.distribute.Strategy
:
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2 938/938 [==============================] - 3s 3ms/step - loss: 0.0476 - sparse_categorical_accuracy: 0.9859 Epoch 2/2 938/938 [==============================] - 3s 3ms/step - loss: 0.0334 - sparse_categorical_accuracy: 0.9895 <keras.callbacks.History at 0x7f3b187b7150>
לאחר שחזור המודל, ניתן להמשיך ולהתאמן עליו, גם ללא צורך לקרוא שוב compile()
, מכיוון שהוא כבר קומפילד לפני השמירה. הדגם נשמר בפורמט פרוטו SavedModel
הסטנדרטי של TensorFlow. למידע נוסף, עיין במדריך לפורמט saved_model
.
כעת לטעון את המודל ולאמן אותו באמצעות tf.distribute.Strategy
:
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2 2022-01-26 05:41:33.036733: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. 2022-01-26 05:41:33.083001: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations. 938/938 [==============================] - 10s 10ms/step - loss: 0.0474 - sparse_categorical_accuracy: 0.9860 Epoch 2/2 938/938 [==============================] - 10s 10ms/step - loss: 0.0327 - sparse_categorical_accuracy: 0.9903
כפי שאתה יכול לראות, הטעינה עובדת כצפוי עם tf.distribute.Strategy
. האסטרטגיה המשמשת כאן לא חייבת להיות אותה האסטרטגיה שבה נעשה שימוש לפני השמירה.
ממשקי ה-API של tf.saved_model
עכשיו בואו נסתכל על ממשקי ה-API ברמה נמוכה יותר. שמירת הדגם דומה ל-Keras API:
model = get_model() # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets
הטעינה יכולה להתבצע באמצעות tf.saved_model.load()
. עם זאת, מכיוון שמדובר ב-API שנמצא ברמה הנמוכה יותר (ולכן יש לו מגוון רחב יותר של מקרי שימוש), הוא לא מחזיר דגם של Keras. במקום זאת, הוא מחזיר אובייקט שמכיל פונקציות שניתן להשתמש בהן כדי להסיק. לדוגמה:
DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
האובייקט הנטען עשוי להכיל מספר פונקציות, כל אחת משויכת למפתח. ה- "serving_default"
הוא מפתח ברירת המחדל של פונקציית ההסקה עם מודל Keras שמור. כדי לעשות מסקנות עם הפונקציה הזו:
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[-1.18789300e-01, -1.78404614e-01, 4.92432676e-02, -9.37875658e-02, 1.14302970e-01, -8.99422392e-02, 9.47709680e-02, -7.75382966e-02, 4.04430032e-02, 2.41404288e-02], [-2.35370561e-01, -3.39397341e-02, 2.73427293e-02, -1.08200148e-01, 5.10682352e-02, 1.36142194e-01, 9.28785652e-02, -5.35808355e-02, 2.56292164e-01, 1.05301209e-01], [-1.91031799e-01, -7.72745535e-02, -7.23153427e-02, -1.99329913e-01, -7.45072216e-02, 2.42738128e-02, 2.07733169e-01, -3.15396488e-03, 4.95976806e-02, 2.14848563e-01], [-9.82482210e-02, -6.13910556e-02, 1.00815810e-01, -1.87558904e-01, 1.14685424e-01, 1.53835595e-01, 1.85714245e-01, -8.74890238e-02, 1.07493028e-01, 1.57510787e-02], [-8.56257528e-02, 3.23683321e-02, -3.66768315e-02, -1.47201523e-01, -5.31517603e-02, 1.52744055e-02, 1.69184029e-01, -5.42814359e-02, 1.11524366e-01, 5.65215349e-02], [-1.50604844e-01, -7.87255913e-03, 1.26651973e-01, -1.24476865e-01, 6.94983900e-02, 4.27672639e-03, 1.86136231e-01, -4.54714149e-03, 9.12746191e-02, 6.12779632e-02], [-2.79157639e-01, -4.61089313e-02, 2.51544192e-02, -1.79003477e-01, 3.83432880e-02, 2.05054253e-01, -8.25636461e-03, -8.25546682e-03, 2.41342247e-01, 8.24805871e-02], [-1.42795354e-01, 6.54597580e-02, 2.05058958e-02, -1.28471941e-01, 1.10977650e-01, 4.51317504e-02, 2.44124904e-01, 1.90523565e-02, 3.11958641e-02, 6.49511665e-02], [-1.33037239e-01, -2.72594951e-02, 8.09026062e-02, -1.95883229e-01, 1.84634060e-01, 1.00822970e-01, 4.40884084e-02, -6.43826872e-02, 1.47807434e-01, -1.92791894e-02], [-1.43770471e-01, -2.53150351e-02, 4.18904647e-02, -1.02573663e-01, 6.15917407e-02, 7.95702711e-02, 9.27314460e-02, -4.31537181e-02, 4.59018350e-02, 1.02965936e-01], [-1.90395206e-01, 2.93233991e-03, 1.48900077e-02, -1.15877971e-01, 1.06598288e-02, 1.40121073e-01, 6.86443001e-02, -4.61921766e-02, 1.27470195e-01, 6.73005953e-02], [-2.60747373e-01, -1.45188004e-01, 7.10044056e-04, -1.04602516e-01, 5.00324890e-02, 2.96664417e-01, 8.57191086e-02, 6.65097907e-02, 1.31302923e-01, -1.84605196e-02], [-1.62942797e-01, -3.63466889e-02, -1.33987352e-01, -1.34576231e-01, -8.19503814e-02, 1.30840242e-02, 6.16783127e-02, -3.64837795e-02, 3.18005830e-02, 1.98420882e-01], [-1.25772715e-01, -6.94367215e-02, -1.35144517e-02, -6.30265176e-02, 8.36028308e-02, 2.96559408e-02, 2.19864860e-01, -7.08417147e-02, 4.76131588e-02, 1.15781695e-01], [-1.55139655e-01, -1.27863720e-01, 9.67459157e-02, -1.48635745e-01, 1.25129193e-01, 4.04443927e-02, 2.94884086e-01, -7.66484886e-02, 1.18753463e-01, 2.93397382e-02], [-1.59221828e-01, -9.30457860e-02, 9.18259323e-02, -1.72857821e-01, 8.09611157e-02, 1.11391053e-01, 1.66679412e-01, 3.52456123e-02, 9.05358568e-02, 9.89414975e-02], [-2.01425552e-01, -4.67008501e-02, -1.62331611e-02, -9.73629057e-02, 1.36456266e-01, 1.30628154e-01, 1.53577864e-01, -6.73157908e-03, 9.31103677e-02, 1.50734074e-02], [-1.29348308e-01, -3.03804129e-03, 2.82487050e-02, -2.02886015e-01, 7.09105879e-02, 1.74542382e-01, 2.57992335e-02, -1.63579211e-02, 2.30892301e-02, 6.69767857e-02], [-1.56857669e-01, 5.46110943e-02, -5.93251809e-02, -1.04585059e-01, 2.61763521e-02, 1.43062070e-01, 1.57771498e-01, -6.19823262e-02, 3.59585434e-02, 6.62322640e-02], [-8.64257440e-02, -1.33483298e-03, 7.46414512e-02, -1.82848468e-01, 1.21074423e-01, 1.55276239e-01, 1.46483868e-01, -6.22515939e-03, 1.91641584e-01, -9.95825827e-02], [-2.52117336e-01, -6.92471862e-02, 1.09911412e-01, -3.73112522e-02, 3.76211852e-03, 5.23591004e-02, 9.16506499e-02, 6.80204183e-02, -4.27842364e-02, 7.91264027e-02], [-2.11018056e-01, 5.97522780e-03, 8.47486481e-02, -7.27925971e-02, 9.36664082e-03, 1.62506998e-01, 5.32426499e-02, 1.78599171e-02, -2.30420940e-02, 4.07365486e-02], [-1.35342121e-01, -4.06659022e-02, -2.09493563e-02, -1.64699793e-01, 8.35808069e-02, 7.68100768e-02, -7.14773983e-02, -3.43702435e-02, 9.47649628e-02, 9.36352089e-02], [-1.20486066e-01, 3.77080180e-02, 1.14158325e-01, -6.50681928e-02, 1.03382617e-02, 1.17891498e-01, 1.13154747e-01, -1.49052702e-02, 1.28893867e-01, 1.12219512e-01], [-2.23867983e-01, -9.79400948e-02, 7.37103820e-02, -1.05197895e-02, 3.75595838e-02, 1.80490598e-01, 6.83145374e-02, -3.09509300e-02, 1.42565176e-01, 8.05927664e-02], [-2.32092351e-01, -3.42734642e-02, -5.15977889e-02, -1.75458089e-01, 1.46448284e-01, 1.80426955e-01, 1.52164772e-01, -2.57370695e-02, 1.26812875e-01, 1.22049123e-01], [-9.45013613e-02, 5.85526973e-02, 1.47456676e-02, -4.40606587e-02, 4.86647561e-02, 6.28624633e-02, 3.69989276e-02, -3.68277319e-02, 3.56127135e-02, 3.10502797e-02], [-1.02712311e-01, 3.16979140e-02, 1.88253060e-01, -5.99608906e-02, 3.73450294e-02, 6.38176724e-02, 1.12240583e-01, 2.42183693e-02, 1.45670772e-02, -9.52028483e-03], [-1.62333213e-02, -1.42737105e-02, -5.79352975e-02, -1.01807326e-01, -7.93362781e-03, -7.22003728e-02, 1.49934232e-01, -1.19943202e-01, 9.22369361e-02, 1.46321565e-01], [-1.32534593e-01, 1.18380897e-02, 2.23980099e-03, -9.28303748e-02, -2.20538303e-02, 7.68908709e-02, 5.29715866e-02, -3.43324393e-02, -1.27909705e-02, -7.04141408e-02], [-8.10261145e-02, -8.95578321e-03, 3.96864787e-02, -1.21861629e-01, 7.98310041e-02, 1.56087667e-01, 9.11872089e-02, -2.29295418e-02, 5.64432219e-02, -3.55931222e-02], [-1.76416740e-01, 1.12043694e-02, -1.80068091e-02, -1.88012689e-01, 8.68914276e-02, 1.57958359e-01, 5.77907935e-02, -2.12088451e-02, 5.33877537e-02, 2.19271183e-02], [-2.70012528e-01, -1.26611829e-01, 3.10387388e-02, -7.24840909e-02, 1.03253610e-01, 8.91268626e-02, 1.38662308e-01, -6.25240132e-02, 2.36210316e-01, 1.40534222e-01], [-8.52961093e-02, -1.15273651e-02, -2.88792588e-02, -2.01282576e-02, 5.43357767e-02, 7.14191943e-02, 3.46604213e-02, -6.00920171e-02, 5.11362031e-02, 3.58160883e-02], [-1.63262367e-01, 2.44849995e-02, 3.81964818e-02, -3.93010303e-02, 3.95263731e-03, 9.11088511e-02, 3.88236046e-02, 1.33745335e-02, 1.00076631e-01, 6.05135933e-02], [-3.01809371e-01, -1.58440098e-01, 4.65333983e-02, -1.63946241e-01, -6.42775744e-02, 3.93286347e-04, 2.82839835e-01, -8.93663988e-02, 1.97781295e-01, 2.87044942e-01], [-2.15368003e-01, -4.83291782e-02, -8.29075277e-03, -1.01776704e-01, 1.43144801e-02, 1.82002857e-02, 2.76539754e-02, -1.94141679e-02, 8.87098238e-02, 6.60644472e-02], [-2.20715180e-01, -7.20694065e-02, -6.08972833e-02, -4.82957587e-02, 1.28858402e-01, 1.30042464e-01, 1.32807568e-01, -7.52742141e-02, 9.51702446e-02, 3.10119465e-02], [-1.09407350e-01, -5.27948700e-03, 1.29588693e-03, -2.61662379e-02, 3.01920641e-02, 1.13487415e-01, 8.23267922e-02, 1.92574020e-02, 2.31986474e-02, 4.13139611e-02], [-2.12277412e-01, -1.35507256e-01, 4.22930568e-02, -1.34565741e-01, 1.17879853e-01, 1.30573064e-01, 1.81054786e-01, -1.70722306e-01, 1.05854876e-01, 7.36362934e-02], [-1.78249478e-01, -7.55607188e-02, 7.75147527e-02, -2.14659080e-01, 3.26948166e-02, 7.76198730e-02, 1.08791113e-01, -2.38809325e-02, 1.79410487e-01, 1.94452941e-01], [-1.92162693e-01, -1.50472090e-01, -8.24331492e-02, -1.40473023e-02, 3.60646360e-02, -9.39090401e-02, 1.83859855e-01, -1.09493822e-01, -3.09051797e-02, 1.36017531e-01], [-9.21519399e-02, -1.53335631e-02, -5.56742400e-02, -9.68495384e-02, 2.35293470e-02, 2.53665410e-02, 1.79999322e-01, -7.10204691e-02, -7.29817525e-02, 4.50368747e-02], [-1.22261971e-01, -6.94630146e-02, -7.97796808e-03, -1.03088826e-01, -7.38603100e-02, 1.84892826e-02, 9.76646394e-02, -3.29037756e-02, -1.77134499e-02, 1.62288889e-01], [-6.78652674e-02, -1.08500615e-01, 5.66991530e-02, -9.52370912e-02, 5.28126955e-02, 1.05176866e-02, 1.73085481e-01, -1.37753151e-02, 1.95556954e-02, 1.38068855e-01], [-2.02808753e-01, -3.39423120e-02, 1.82233751e-03, -5.71424365e-02, 3.40205729e-02, 8.74454305e-02, 8.47227685e-03, -2.52498202e-02, 4.66104299e-02, 1.10718749e-01], [-9.52449068e-02, -3.35062481e-02, -1.00178778e-01, -9.72513855e-02, -3.58061343e-02, 3.04423086e-02, 5.70362583e-02, -4.03833576e-02, -4.28436548e-02, 9.73245874e-02], [-2.06081957e-01, -1.71493232e-01, 2.52560824e-02, -1.55212343e-01, -4.33478206e-02, 2.34177694e-01, 8.46128762e-02, 1.75322518e-02, 2.04347119e-01, 1.54971585e-01], [-1.95310384e-01, 1.30968075e-02, -9.68117267e-03, -7.31432810e-02, 1.02618083e-01, 1.59629256e-01, 1.66028887e-01, -7.12903216e-03, 1.78021699e-01, -2.17130631e-02], [-1.59163624e-01, -1.77137554e-05, 1.75410658e-02, -9.08103511e-02, 7.25786015e-02, 9.21041369e-02, 1.24915361e-01, -6.55939505e-02, -1.13440230e-02, 1.03661232e-01], [-1.93366870e-01, -4.36344892e-02, 1.37750164e-01, -1.91939399e-01, -1.50268525e-03, 8.03942382e-02, 2.15812266e-01, 5.38492575e-02, 1.36685073e-01, 2.22119391e-01], [-1.65946245e-01, 7.89588690e-03, -1.65037125e-01, -1.23690292e-01, -8.57629776e-02, -2.55736727e-02, 1.67541012e-01, -6.63827211e-02, 2.98694819e-02, 1.71927184e-01], [-1.56264767e-01, -1.72245800e-02, -4.98924702e-02, -2.98387632e-02, 2.80477256e-02, 4.94132042e-02, 4.89805043e-02, 1.96998678e-02, -4.14144360e-02, -5.05549274e-02], [-1.46449029e-01, -1.12528354e-01, -4.66653258e-02, -3.78398523e-02, 7.60737807e-03, -2.70657167e-02, 1.11277811e-01, 6.37479573e-02, -2.39458829e-02, 1.22067556e-01], [-1.92323536e-01, -1.43002480e-01, 5.29062748e-03, -1.70663983e-01, 8.39572400e-03, 6.37906119e-02, 1.24084033e-01, 6.02792688e-02, 7.18353763e-02, 5.03963791e-03], [-1.70977920e-01, 1.04207098e-02, 1.18544906e-01, -4.29532528e-02, -3.53983864e-02, 1.80302024e-01, 8.08775946e-02, 3.19045782e-02, 2.52931342e-02, 1.29424319e-01], [-2.13301033e-01, -6.96119964e-02, 2.32847631e-02, -7.73920864e-02, 1.10387571e-01, 1.13307782e-01, 1.41805351e-01, -5.19381016e-02, 1.15313083e-01, 1.40049949e-01], [-1.71651557e-01, -5.98860830e-02, -3.92800570e-03, -1.04376137e-01, 7.78115019e-02, 6.84583709e-02, 2.51923770e-01, -1.05199262e-01, 1.64517179e-01, 2.18875334e-01], [-2.60777414e-01, -8.93031508e-02, 1.27723843e-01, -1.97950065e-01, 1.19145498e-01, 7.30907321e-02, 2.23771721e-01, -6.83849230e-02, 3.68930906e-01, 1.86811388e-01], [-2.38028213e-01, 1.11199915e-03, 2.25015372e-01, 8.22724327e-02, -1.14511400e-01, 1.57513067e-01, 5.22858277e-02, 2.13724375e-03, 3.15639377e-02, 2.08704025e-01], [-1.46687120e-01, -1.10313833e-01, -1.16352811e-02, -1.44550815e-01, 2.09794566e-02, 1.47883072e-02, 3.96856442e-02, -2.15019658e-03, -4.90810722e-02, 1.34708211e-01], [-2.02591017e-01, -2.29728431e-01, 6.73423260e-02, -1.24901496e-01, -1.38434023e-02, 8.64367038e-02, 1.22342721e-01, 1.67826824e-02, 1.65354639e-01, 1.83434993e-01], [-2.25799978e-01, -1.02682747e-01, 9.48531851e-02, -9.38871950e-02, 1.03806734e-01, 2.04695478e-01, 8.09893832e-02, -1.45416632e-02, 1.33486420e-01, -6.27665371e-02], [-1.19375348e-01, 2.23235339e-02, 1.04302749e-01, -1.11149743e-01, 6.12434298e-02, 6.89433664e-02, 2.08741099e-01, -3.81497070e-02, -1.42122135e-02, 7.65201449e-03]], dtype=float32)>} 2022-01-26 05:41:53.590742: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
אתה יכול גם לטעון ולעשות מסקנות בצורה מבוזרת:
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
dist_predict_dataset = another_strategy.experimental_distribute_dataset(
predict_dataset)
# Calling the function in a distributed manner
for batch in dist_predict_dataset:
another_strategy.run(inference_func,args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) 2022-01-26 05:41:53.931428: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
קריאה לפונקציה המשוחזרת היא רק העברה קדימה בדגם השמור (חיזוי). מה אם אתה רוצה להמשיך לאמן את הפונקציה העמוסה? או להטמיע את הפונקציה הטעונה בדגם גדול יותר? נוהג נפוץ הוא לעטוף את האובייקט הטעון הזה לשכבת Keras כדי להשיג זאת. למרבה המזל, ל- TF Hub יש hub.KerasLayer למטרה זו, המוצג כאן:
import tensorflow_hub as hub
def build_model(loaded):
x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
# Wrap what's loaded to a KerasLayer
keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
model = tf.keras.Model(x, keras_layer)
return model
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
loaded = tf.saved_model.load(saved_model_path)
model = build_model(loaded)
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) Epoch 1/2 2022-01-26 05:41:55.594317: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. 938/938 [==============================] - 6s 3ms/step - loss: 0.1910 - sparse_categorical_accuracy: 0.9442 Epoch 2/2 938/938 [==============================] - 3s 4ms/step - loss: 0.0633 - sparse_categorical_accuracy: 0.9813
כפי שאתה יכול לראות, hub.KerasLayer
עוטף את התוצאה שנטענה בחזרה מ- tf.saved_model.load()
לשכבת Keras שניתן להשתמש בה כדי לבנות מודל נוסף. זה מאוד שימושי ללמידה בהעברה.
באיזה API עלי להשתמש?
לצורך שמירה, אם אתה עובד עם מודל keras, כמעט תמיד מומלץ להשתמש ב-API model.save()
של Keras. אם מה שאתה חוסך הוא לא דגם של Keras, אז ה-API ברמה נמוכה יותר הוא הבחירה היחידה שלך.
לטעינה, באיזה API אתה משתמש תלוי במה שאתה רוצה לקבל מ-API לטעינה. אם אינך יכול (או לא רוצה) לקבל מודל Keras, השתמש ב- tf.saved_model.load()
. אחרת, השתמש ב- tf.keras.models.load_model()
. שימו לב שאתם יכולים לקבל בחזרה דגם של קרס רק אם שמרתם דגם של קרס.
אפשר לערבב ולהתאים את ה-API. אתה יכול לשמור דגם של Keras עם model.save
, ולטעון דגם שאינו Keras עם ה-API ברמה נמוכה, tf.saved_model.load
.
model = get_model()
# Saving the model using Keras's save() API
model.save(keras_model_path)
another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets INFO:tensorflow:Assets written to: /tmp/keras_save/assets INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
שמירה/טעינה ממכשיר מקומי
בעת שמירה וטעינה מהתקן io מקומי תוך כדי ריצה מרחוק, למשל באמצעות TPU בענן, יש להשתמש באפשרות experimental_io_device
כדי להגדיר את מכשיר ה-io ל-localhost.
model = get_model()
# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)
# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
אזהרות
מקרה מיוחד הוא כאשר יש לך דגם של Keras שאין לו כניסות מוגדרות היטב. לדוגמה, ניתן ליצור מודל Sequential ללא צורות קלט כלשהן ( Sequential([Dense(3), ...]
). גם למודלים מסווגים אין כניסות מוגדרות היטב לאחר האתחול. במקרה זה, עליך לדבוק ב- ממשקי API ברמה נמוכה יותר הן בשמירה והן בטעינה, אחרת תקבל שגיאה.
כדי לבדוק אם לדגם שלך יש כניסות מוגדרות היטב, פשוט בדוק אם model.inputs
הוא None
. אם זה לא None
, כולכם טובים. צורות קלט מוגדרות אוטומטית כאשר המודל משמש ב-. .predict
, .fit
, .evaluate
או בעת קריאה למודל ( model(inputs)
).
הנה דוגמה:
class SubclassedModel(tf.keras.Model):
output_name = 'output_layer'
def __init__(self):
super(SubclassedModel, self).__init__()
self._dense_layer = tf.keras.layers.Dense(
5, dtype=tf.dtypes.float32, name=self.output_name)
def call(self, inputs):
return self._dense_layer(inputs)
my_model = SubclassedModel()
# my_model.save(keras_model_path) # ERROR!
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f3ad00f3510>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f3ad00f3510>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f3ad00f3e90>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f3ad00f3e90>, because it is not built. INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets