TensorFlow.org-এ দেখুন | Google Colab-এ চালান | GitHub-এ উৎস দেখুন | নোটবুক ডাউনলোড করুন |
ওভারভিউ
প্রশিক্ষণের সময় একটি মডেল সংরক্ষণ এবং লোড করা সাধারণ। কেরাস মডেল সংরক্ষণ এবং লোড করার জন্য API-এর দুটি সেট রয়েছে: একটি উচ্চ-স্তরের API, এবং একটি নিম্ন-স্তরের API। এই টিউটোরিয়ালটি দেখায় কিভাবে আপনি tf.distribute.Strategy ব্যবহার করার সময় tf.distribute.Strategy
API ব্যবহার করতে পারেন। SavedModel এবং সাধারণভাবে সিরিয়ালাইজেশন সম্পর্কে জানতে, অনুগ্রহ করে সংরক্ষিত মডেল গাইড এবং Keras মডেল সিরিয়ালাইজেশন গাইড পড়ুন। একটি সহজ উদাহরণ দিয়ে শুরু করা যাক:
আমদানি নির্ভরতা:
import tensorflow_datasets as tfds
import tensorflow as tf
tf.distribute.Strategy
ব্যবহার করে ডেটা এবং মডেল প্রস্তুত করুন:
mirrored_strategy = tf.distribute.MirroredStrategy()
def get_data():
datasets, ds_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
return train_dataset, eval_dataset
def get_model():
with mirrored_strategy.scope():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
মডেলকে প্রশিক্ষণ দিন:
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). Epoch 1/2 2022-01-26 05:41:11.916000: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 938/938 [==============================] - 11s 5ms/step - loss: 0.1873 - sparse_categorical_accuracy: 0.9451 Epoch 2/2 938/938 [==============================] - 3s 3ms/step - loss: 0.0641 - sparse_categorical_accuracy: 0.9807 <keras.callbacks.History at 0x7f3b900396d0>
সংরক্ষণ করুন এবং মডেল লোড
এখন যেহেতু আপনার কাছে কাজ করার জন্য একটি সাধারণ মডেল আছে, চলুন সেভিং/লোডিং API গুলো দেখে নেওয়া যাক। API এর দুটি সেট উপলব্ধ রয়েছে:
- উচ্চ স্তরের
model.save
এবংtf.keras.models.load_model
- নিম্ন স্তরের
tf.saved_model.save
এবংtf.saved_model.load
কেরাস এপিআই
কেরাস এপিআইগুলির সাথে একটি মডেল সংরক্ষণ এবং লোড করার একটি উদাহরণ এখানে রয়েছে:
keras_model_path = "/tmp/keras_save"
model.save(keras_model_path)
2022-01-26 05:41:26.593570: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them. INFO:tensorflow:Assets written to: /tmp/keras_save/assets INFO:tensorflow:Assets written to: /tmp/keras_save/assets
tf.distribute.Strategy
ছাড়াই মডেলটি পুনরুদ্ধার করুন:
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2 938/938 [==============================] - 3s 3ms/step - loss: 0.0476 - sparse_categorical_accuracy: 0.9859 Epoch 2/2 938/938 [==============================] - 3s 3ms/step - loss: 0.0334 - sparse_categorical_accuracy: 0.9895 <keras.callbacks.History at 0x7f3b187b7150>
মডেল পুনরুদ্ধার করার পরে, আপনি এটির প্রশিক্ষণ চালিয়ে যেতে পারেন, এমনকি compile()
কে আবার কল করার প্রয়োজন ছাড়াই, যেহেতু এটি সংরক্ষণ করার আগে ইতিমধ্যেই সংকলিত হয়েছে। মডেলটি TensorFlow এর স্ট্যান্ডার্ড SavedModel
প্রোটো ফরম্যাটে সংরক্ষিত হয়েছে। আরও তথ্যের জন্য, অনুগ্রহ করে saved_model
ফরম্যাটের নির্দেশিকা পড়ুন।
এখন মডেলটি লোড করতে এবং একটি tf.distribute.Strategy
ব্যবহার করে প্রশিক্ষণ দিতে:
another_strategy = tf.distribute.OneDeviceStrategy("/cpu:0")
with another_strategy.scope():
restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2 2022-01-26 05:41:33.036733: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. 2022-01-26 05:41:33.083001: W tensorflow/core/framework/dataset.cc:768] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations. 938/938 [==============================] - 10s 10ms/step - loss: 0.0474 - sparse_categorical_accuracy: 0.9860 Epoch 2/2 938/938 [==============================] - 10s 10ms/step - loss: 0.0327 - sparse_categorical_accuracy: 0.9903
আপনি দেখতে পাচ্ছেন, লোডিং tf.distribute.Strategy
সাথে প্রত্যাশিতভাবে কাজ করে। এখানে ব্যবহৃত কৌশল সংরক্ষণ করার আগে ব্যবহৃত একই কৌশল হতে হবে না।
tf.saved_model
APIs
এখন নিম্ন স্তরের API গুলো দেখে নেওয়া যাক। মডেলটি সংরক্ষণ করা কেরাস API এর অনুরূপ:
model = get_model() # get a fresh model
saved_model_path = "/tmp/tf_save"
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets
লোডিং tf.saved_model.load()
দিয়ে করা যেতে পারে। যাইহোক, যেহেতু এটি একটি এপিআই যা নিম্ন স্তরে রয়েছে (এবং তাই এর ব্যবহারের ক্ষেত্রে বিস্তৃত পরিসর রয়েছে), এটি কেরাস মডেল ফেরত দেয় না। পরিবর্তে, এটি এমন একটি বস্তু ফেরত দেয় যাতে ফাংশন থাকে যা অনুমান করতে ব্যবহার করা যেতে পারে। উদাহরণ স্বরূপ:
DEFAULT_FUNCTION_KEY = "serving_default"
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
লোড করা বস্তুতে একাধিক ফাংশন থাকতে পারে, প্রতিটি একটি কী-এর সাথে যুক্ত। "serving_default"
হল একটি সংরক্ষিত কেরাস মডেলের সাথে ইনফারেন্স ফাংশনের জন্য ডিফল্ট কী। এই ফাংশনের সাথে একটি অনুমান করতে:
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[-1.18789300e-01, -1.78404614e-01, 4.92432676e-02, -9.37875658e-02, 1.14302970e-01, -8.99422392e-02, 9.47709680e-02, -7.75382966e-02, 4.04430032e-02, 2.41404288e-02], [-2.35370561e-01, -3.39397341e-02, 2.73427293e-02, -1.08200148e-01, 5.10682352e-02, 1.36142194e-01, 9.28785652e-02, -5.35808355e-02, 2.56292164e-01, 1.05301209e-01], [-1.91031799e-01, -7.72745535e-02, -7.23153427e-02, -1.99329913e-01, -7.45072216e-02, 2.42738128e-02, 2.07733169e-01, -3.15396488e-03, 4.95976806e-02, 2.14848563e-01], [-9.82482210e-02, -6.13910556e-02, 1.00815810e-01, -1.87558904e-01, 1.14685424e-01, 1.53835595e-01, 1.85714245e-01, -8.74890238e-02, 1.07493028e-01, 1.57510787e-02], [-8.56257528e-02, 3.23683321e-02, -3.66768315e-02, -1.47201523e-01, -5.31517603e-02, 1.52744055e-02, 1.69184029e-01, -5.42814359e-02, 1.11524366e-01, 5.65215349e-02], [-1.50604844e-01, -7.87255913e-03, 1.26651973e-01, -1.24476865e-01, 6.94983900e-02, 4.27672639e-03, 1.86136231e-01, -4.54714149e-03, 9.12746191e-02, 6.12779632e-02], [-2.79157639e-01, -4.61089313e-02, 2.51544192e-02, -1.79003477e-01, 3.83432880e-02, 2.05054253e-01, -8.25636461e-03, -8.25546682e-03, 2.41342247e-01, 8.24805871e-02], [-1.42795354e-01, 6.54597580e-02, 2.05058958e-02, -1.28471941e-01, 1.10977650e-01, 4.51317504e-02, 2.44124904e-01, 1.90523565e-02, 3.11958641e-02, 6.49511665e-02], [-1.33037239e-01, -2.72594951e-02, 8.09026062e-02, -1.95883229e-01, 1.84634060e-01, 1.00822970e-01, 4.40884084e-02, -6.43826872e-02, 1.47807434e-01, -1.92791894e-02], [-1.43770471e-01, -2.53150351e-02, 4.18904647e-02, -1.02573663e-01, 6.15917407e-02, 7.95702711e-02, 9.27314460e-02, -4.31537181e-02, 4.59018350e-02, 1.02965936e-01], [-1.90395206e-01, 2.93233991e-03, 1.48900077e-02, -1.15877971e-01, 1.06598288e-02, 1.40121073e-01, 6.86443001e-02, -4.61921766e-02, 1.27470195e-01, 6.73005953e-02], [-2.60747373e-01, -1.45188004e-01, 7.10044056e-04, -1.04602516e-01, 5.00324890e-02, 2.96664417e-01, 8.57191086e-02, 6.65097907e-02, 1.31302923e-01, -1.84605196e-02], [-1.62942797e-01, -3.63466889e-02, -1.33987352e-01, -1.34576231e-01, -8.19503814e-02, 1.30840242e-02, 6.16783127e-02, -3.64837795e-02, 3.18005830e-02, 1.98420882e-01], [-1.25772715e-01, -6.94367215e-02, -1.35144517e-02, -6.30265176e-02, 8.36028308e-02, 2.96559408e-02, 2.19864860e-01, -7.08417147e-02, 4.76131588e-02, 1.15781695e-01], [-1.55139655e-01, -1.27863720e-01, 9.67459157e-02, -1.48635745e-01, 1.25129193e-01, 4.04443927e-02, 2.94884086e-01, -7.66484886e-02, 1.18753463e-01, 2.93397382e-02], [-1.59221828e-01, -9.30457860e-02, 9.18259323e-02, -1.72857821e-01, 8.09611157e-02, 1.11391053e-01, 1.66679412e-01, 3.52456123e-02, 9.05358568e-02, 9.89414975e-02], [-2.01425552e-01, -4.67008501e-02, -1.62331611e-02, -9.73629057e-02, 1.36456266e-01, 1.30628154e-01, 1.53577864e-01, -6.73157908e-03, 9.31103677e-02, 1.50734074e-02], [-1.29348308e-01, -3.03804129e-03, 2.82487050e-02, -2.02886015e-01, 7.09105879e-02, 1.74542382e-01, 2.57992335e-02, -1.63579211e-02, 2.30892301e-02, 6.69767857e-02], [-1.56857669e-01, 5.46110943e-02, -5.93251809e-02, -1.04585059e-01, 2.61763521e-02, 1.43062070e-01, 1.57771498e-01, -6.19823262e-02, 3.59585434e-02, 6.62322640e-02], [-8.64257440e-02, -1.33483298e-03, 7.46414512e-02, -1.82848468e-01, 1.21074423e-01, 1.55276239e-01, 1.46483868e-01, -6.22515939e-03, 1.91641584e-01, -9.95825827e-02], [-2.52117336e-01, -6.92471862e-02, 1.09911412e-01, -3.73112522e-02, 3.76211852e-03, 5.23591004e-02, 9.16506499e-02, 6.80204183e-02, -4.27842364e-02, 7.91264027e-02], [-2.11018056e-01, 5.97522780e-03, 8.47486481e-02, -7.27925971e-02, 9.36664082e-03, 1.62506998e-01, 5.32426499e-02, 1.78599171e-02, -2.30420940e-02, 4.07365486e-02], [-1.35342121e-01, -4.06659022e-02, -2.09493563e-02, -1.64699793e-01, 8.35808069e-02, 7.68100768e-02, -7.14773983e-02, -3.43702435e-02, 9.47649628e-02, 9.36352089e-02], [-1.20486066e-01, 3.77080180e-02, 1.14158325e-01, -6.50681928e-02, 1.03382617e-02, 1.17891498e-01, 1.13154747e-01, -1.49052702e-02, 1.28893867e-01, 1.12219512e-01], [-2.23867983e-01, -9.79400948e-02, 7.37103820e-02, -1.05197895e-02, 3.75595838e-02, 1.80490598e-01, 6.83145374e-02, -3.09509300e-02, 1.42565176e-01, 8.05927664e-02], [-2.32092351e-01, -3.42734642e-02, -5.15977889e-02, -1.75458089e-01, 1.46448284e-01, 1.80426955e-01, 1.52164772e-01, -2.57370695e-02, 1.26812875e-01, 1.22049123e-01], [-9.45013613e-02, 5.85526973e-02, 1.47456676e-02, -4.40606587e-02, 4.86647561e-02, 6.28624633e-02, 3.69989276e-02, -3.68277319e-02, 3.56127135e-02, 3.10502797e-02], [-1.02712311e-01, 3.16979140e-02, 1.88253060e-01, -5.99608906e-02, 3.73450294e-02, 6.38176724e-02, 1.12240583e-01, 2.42183693e-02, 1.45670772e-02, -9.52028483e-03], [-1.62333213e-02, -1.42737105e-02, -5.79352975e-02, -1.01807326e-01, -7.93362781e-03, -7.22003728e-02, 1.49934232e-01, -1.19943202e-01, 9.22369361e-02, 1.46321565e-01], [-1.32534593e-01, 1.18380897e-02, 2.23980099e-03, -9.28303748e-02, -2.20538303e-02, 7.68908709e-02, 5.29715866e-02, -3.43324393e-02, -1.27909705e-02, -7.04141408e-02], [-8.10261145e-02, -8.95578321e-03, 3.96864787e-02, -1.21861629e-01, 7.98310041e-02, 1.56087667e-01, 9.11872089e-02, -2.29295418e-02, 5.64432219e-02, -3.55931222e-02], [-1.76416740e-01, 1.12043694e-02, -1.80068091e-02, -1.88012689e-01, 8.68914276e-02, 1.57958359e-01, 5.77907935e-02, -2.12088451e-02, 5.33877537e-02, 2.19271183e-02], [-2.70012528e-01, -1.26611829e-01, 3.10387388e-02, -7.24840909e-02, 1.03253610e-01, 8.91268626e-02, 1.38662308e-01, -6.25240132e-02, 2.36210316e-01, 1.40534222e-01], [-8.52961093e-02, -1.15273651e-02, -2.88792588e-02, -2.01282576e-02, 5.43357767e-02, 7.14191943e-02, 3.46604213e-02, -6.00920171e-02, 5.11362031e-02, 3.58160883e-02], [-1.63262367e-01, 2.44849995e-02, 3.81964818e-02, -3.93010303e-02, 3.95263731e-03, 9.11088511e-02, 3.88236046e-02, 1.33745335e-02, 1.00076631e-01, 6.05135933e-02], [-3.01809371e-01, -1.58440098e-01, 4.65333983e-02, -1.63946241e-01, -6.42775744e-02, 3.93286347e-04, 2.82839835e-01, -8.93663988e-02, 1.97781295e-01, 2.87044942e-01], [-2.15368003e-01, -4.83291782e-02, -8.29075277e-03, -1.01776704e-01, 1.43144801e-02, 1.82002857e-02, 2.76539754e-02, -1.94141679e-02, 8.87098238e-02, 6.60644472e-02], [-2.20715180e-01, -7.20694065e-02, -6.08972833e-02, -4.82957587e-02, 1.28858402e-01, 1.30042464e-01, 1.32807568e-01, -7.52742141e-02, 9.51702446e-02, 3.10119465e-02], [-1.09407350e-01, -5.27948700e-03, 1.29588693e-03, -2.61662379e-02, 3.01920641e-02, 1.13487415e-01, 8.23267922e-02, 1.92574020e-02, 2.31986474e-02, 4.13139611e-02], [-2.12277412e-01, -1.35507256e-01, 4.22930568e-02, -1.34565741e-01, 1.17879853e-01, 1.30573064e-01, 1.81054786e-01, -1.70722306e-01, 1.05854876e-01, 7.36362934e-02], [-1.78249478e-01, -7.55607188e-02, 7.75147527e-02, -2.14659080e-01, 3.26948166e-02, 7.76198730e-02, 1.08791113e-01, -2.38809325e-02, 1.79410487e-01, 1.94452941e-01], [-1.92162693e-01, -1.50472090e-01, -8.24331492e-02, -1.40473023e-02, 3.60646360e-02, -9.39090401e-02, 1.83859855e-01, -1.09493822e-01, -3.09051797e-02, 1.36017531e-01], [-9.21519399e-02, -1.53335631e-02, -5.56742400e-02, -9.68495384e-02, 2.35293470e-02, 2.53665410e-02, 1.79999322e-01, -7.10204691e-02, -7.29817525e-02, 4.50368747e-02], [-1.22261971e-01, -6.94630146e-02, -7.97796808e-03, -1.03088826e-01, -7.38603100e-02, 1.84892826e-02, 9.76646394e-02, -3.29037756e-02, -1.77134499e-02, 1.62288889e-01], [-6.78652674e-02, -1.08500615e-01, 5.66991530e-02, -9.52370912e-02, 5.28126955e-02, 1.05176866e-02, 1.73085481e-01, -1.37753151e-02, 1.95556954e-02, 1.38068855e-01], [-2.02808753e-01, -3.39423120e-02, 1.82233751e-03, -5.71424365e-02, 3.40205729e-02, 8.74454305e-02, 8.47227685e-03, -2.52498202e-02, 4.66104299e-02, 1.10718749e-01], [-9.52449068e-02, -3.35062481e-02, -1.00178778e-01, -9.72513855e-02, -3.58061343e-02, 3.04423086e-02, 5.70362583e-02, -4.03833576e-02, -4.28436548e-02, 9.73245874e-02], [-2.06081957e-01, -1.71493232e-01, 2.52560824e-02, -1.55212343e-01, -4.33478206e-02, 2.34177694e-01, 8.46128762e-02, 1.75322518e-02, 2.04347119e-01, 1.54971585e-01], [-1.95310384e-01, 1.30968075e-02, -9.68117267e-03, -7.31432810e-02, 1.02618083e-01, 1.59629256e-01, 1.66028887e-01, -7.12903216e-03, 1.78021699e-01, -2.17130631e-02], [-1.59163624e-01, -1.77137554e-05, 1.75410658e-02, -9.08103511e-02, 7.25786015e-02, 9.21041369e-02, 1.24915361e-01, -6.55939505e-02, -1.13440230e-02, 1.03661232e-01], [-1.93366870e-01, -4.36344892e-02, 1.37750164e-01, -1.91939399e-01, -1.50268525e-03, 8.03942382e-02, 2.15812266e-01, 5.38492575e-02, 1.36685073e-01, 2.22119391e-01], [-1.65946245e-01, 7.89588690e-03, -1.65037125e-01, -1.23690292e-01, -8.57629776e-02, -2.55736727e-02, 1.67541012e-01, -6.63827211e-02, 2.98694819e-02, 1.71927184e-01], [-1.56264767e-01, -1.72245800e-02, -4.98924702e-02, -2.98387632e-02, 2.80477256e-02, 4.94132042e-02, 4.89805043e-02, 1.96998678e-02, -4.14144360e-02, -5.05549274e-02], [-1.46449029e-01, -1.12528354e-01, -4.66653258e-02, -3.78398523e-02, 7.60737807e-03, -2.70657167e-02, 1.11277811e-01, 6.37479573e-02, -2.39458829e-02, 1.22067556e-01], [-1.92323536e-01, -1.43002480e-01, 5.29062748e-03, -1.70663983e-01, 8.39572400e-03, 6.37906119e-02, 1.24084033e-01, 6.02792688e-02, 7.18353763e-02, 5.03963791e-03], [-1.70977920e-01, 1.04207098e-02, 1.18544906e-01, -4.29532528e-02, -3.53983864e-02, 1.80302024e-01, 8.08775946e-02, 3.19045782e-02, 2.52931342e-02, 1.29424319e-01], [-2.13301033e-01, -6.96119964e-02, 2.32847631e-02, -7.73920864e-02, 1.10387571e-01, 1.13307782e-01, 1.41805351e-01, -5.19381016e-02, 1.15313083e-01, 1.40049949e-01], [-1.71651557e-01, -5.98860830e-02, -3.92800570e-03, -1.04376137e-01, 7.78115019e-02, 6.84583709e-02, 2.51923770e-01, -1.05199262e-01, 1.64517179e-01, 2.18875334e-01], [-2.60777414e-01, -8.93031508e-02, 1.27723843e-01, -1.97950065e-01, 1.19145498e-01, 7.30907321e-02, 2.23771721e-01, -6.83849230e-02, 3.68930906e-01, 1.86811388e-01], [-2.38028213e-01, 1.11199915e-03, 2.25015372e-01, 8.22724327e-02, -1.14511400e-01, 1.57513067e-01, 5.22858277e-02, 2.13724375e-03, 3.15639377e-02, 2.08704025e-01], [-1.46687120e-01, -1.10313833e-01, -1.16352811e-02, -1.44550815e-01, 2.09794566e-02, 1.47883072e-02, 3.96856442e-02, -2.15019658e-03, -4.90810722e-02, 1.34708211e-01], [-2.02591017e-01, -2.29728431e-01, 6.73423260e-02, -1.24901496e-01, -1.38434023e-02, 8.64367038e-02, 1.22342721e-01, 1.67826824e-02, 1.65354639e-01, 1.83434993e-01], [-2.25799978e-01, -1.02682747e-01, 9.48531851e-02, -9.38871950e-02, 1.03806734e-01, 2.04695478e-01, 8.09893832e-02, -1.45416632e-02, 1.33486420e-01, -6.27665371e-02], [-1.19375348e-01, 2.23235339e-02, 1.04302749e-01, -1.11149743e-01, 6.12434298e-02, 6.89433664e-02, 2.08741099e-01, -3.81497070e-02, -1.42122135e-02, 7.65201449e-03]], dtype=float32)>} 2022-01-26 05:41:53.590742: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
আপনি একটি বিতরণ পদ্ধতিতে লোড এবং অনুমান করতে পারেন:
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
dist_predict_dataset = another_strategy.experimental_distribute_dataset(
predict_dataset)
# Calling the function in a distributed manner
for batch in dist_predict_dataset:
another_strategy.run(inference_func,args=(batch,))
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) 2022-01-26 05:41:53.931428: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance.
পুনরুদ্ধার করা ফাংশন কল করা সংরক্ষিত মডেলের একটি ফরওয়ার্ড পাস মাত্র (ভবিষ্যদ্বাণী করুন)। আপনি যদি লোড ফাংশন প্রশিক্ষণ চালিয়ে যেতে চান? অথবা একটি বড় মডেল মধ্যে লোড ফাংশন এম্বেড? এটি অর্জনের জন্য একটি সাধারণ অভ্যাস হল এই লোড করা বস্তুটিকে কেরাস স্তরে মোড়ানো। ভাগ্যক্রমে, TF Hub এই উদ্দেশ্যে hub.KerasLayer আছে, এখানে দেখানো হয়েছে:
import tensorflow_hub as hub
def build_model(loaded):
x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
# Wrap what's loaded to a KerasLayer
keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
model = tf.keras.Model(x, keras_layer)
return model
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
loaded = tf.saved_model.load(saved_model_path)
model = build_model(loaded)
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) Epoch 1/2 2022-01-26 05:41:55.594317: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. 938/938 [==============================] - 6s 3ms/step - loss: 0.1910 - sparse_categorical_accuracy: 0.9442 Epoch 2/2 938/938 [==============================] - 3s 4ms/step - loss: 0.0633 - sparse_categorical_accuracy: 0.9813
আপনি দেখতে পাচ্ছেন, hub.KerasLayer
tf.saved_model.load()
থেকে লোড করা ফলাফলটিকে কেরাস লেয়ারে র্যাপ করে যা অন্য মডেল তৈরি করতে ব্যবহার করা যেতে পারে। এটি স্থানান্তর শেখার জন্য খুব দরকারী।
কোন API ব্যবহার করা উচিত?
সংরক্ষণের জন্য, আপনি যদি কেরাস মডেলের সাথে কাজ করেন তবে প্রায় সবসময় model.save()
API ব্যবহার করার পরামর্শ দেওয়া হয়। আপনি যা সংরক্ষণ করছেন তা যদি কেরাস মডেল না হয়, তাহলে নিম্ন স্তরের API আপনার একমাত্র পছন্দ।
লোড করার জন্য, আপনি কোন API ব্যবহার করবেন তা নির্ভর করে আপনি লোডিং API থেকে কী পেতে চান। আপনি যদি কেরাস মডেল পেতে না পারেন (বা করতে চান না), তাহলে tf.saved_model.load()
ব্যবহার করুন। অন্যথায়, tf.keras.models.load_model()
ব্যবহার করুন। মনে রাখবেন যে আপনি কেরাস মডেলটি সংরক্ষণ করলেই আপনি একটি কেরাস মডেল ফিরে পেতে পারেন৷
এপিআইগুলিকে মিশ্রিত করা এবং মিলানো সম্ভব। আপনি model.save
দিয়ে কেরাস মডেল সংরক্ষণ করতে পারেন এবং নিম্ন-স্তরের API, tf.saved_model.load
সহ একটি নন-কেরাস মডেল লোড করতে পারেন।
model = get_model()
# Saving the model using Keras's save() API
model.save(keras_model_path)
another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using lower level API
with another_strategy.scope():
loaded = tf.saved_model.load(keras_model_path)
INFO:tensorflow:Assets written to: /tmp/keras_save/assets INFO:tensorflow:Assets written to: /tmp/keras_save/assets INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
স্থানীয় ডিভাইস থেকে সংরক্ষণ/লোড হচ্ছে
দূরবর্তীভাবে চালানোর সময় একটি স্থানীয় io ডিভাইস থেকে সংরক্ষণ এবং লোড করার সময়, উদাহরণস্বরূপ একটি ক্লাউড TPU ব্যবহার করে, io ডিভাইসটিকে লোকালহোস্টে সেট করতে experimental_io_device
io_device বিকল্পটি ব্যবহার করা আবশ্যক।
model = get_model()
# Saving the model to a path on localhost.
saved_model_path = "/tmp/tf_save"
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)
# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
সতর্কতা
একটি বিশেষ কেস হল যখন আপনার কাছে একটি কেরাস মডেল থাকে যাতে ভালভাবে সংজ্ঞায়িত ইনপুট নেই৷ উদাহরণস্বরূপ, একটি অনুক্রমিক মডেল কোনো ইনপুট আকার ছাড়াই তৈরি করা যেতে পারে ( Sequential([Dense(3), ...]
)। উপশ্রেণিত মডেলগুলিতেও আরম্ভ করার পরে সু-সংজ্ঞায়িত ইনপুট থাকে না। এই ক্ষেত্রে, আপনার সাথে থাকা উচিত সংরক্ষণ এবং লোড উভয় ক্ষেত্রেই নিম্ন স্তরের API, অন্যথায় আপনি একটি ত্রুটি পাবেন।
আপনার মডেলে সু-সংজ্ঞায়িত ইনপুট আছে কিনা তা পরীক্ষা করতে, model.inputs
None
কিনা তা পরীক্ষা করুন। যদি এটা না None
, আপনি সব ভাল. ইনপুট আকারগুলি স্বয়ংক্রিয়ভাবে সংজ্ঞায়িত হয় যখন মডেলটি .fit
, .evaluate
, .predict
, অথবা মডেলটিকে কল করার সময় ( model(inputs)
) ব্যবহার করা হয়৷
এখানে একটি উদাহরণ:
class SubclassedModel(tf.keras.Model):
output_name = 'output_layer'
def __init__(self):
super(SubclassedModel, self).__init__()
self._dense_layer = tf.keras.layers.Dense(
5, dtype=tf.dtypes.float32, name=self.output_name)
def call(self, inputs):
return self._dense_layer(inputs)
my_model = SubclassedModel()
# my_model.save(keras_model_path) # ERROR!
tf.saved_model.save(my_model, saved_model_path)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f3ad00f3510>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f3ad00f3510>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f3ad00f3e90>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.layers.core.dense.Dense object at 0x7f3ad00f3e90>, because it is not built. INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets