ดูบน TensorFlow.org | ทำงานใน Google Colab | ดูแหล่งที่มาบน GitHub | ดาวน์โหลดโน๊ตบุ๊ค |
เมื่อคุณย้ายโมเดลของคุณจากกราฟและเซสชันของ TensorFlow 1 ไปยัง TensorFlow 2 API เช่น tf.function
, tf.Module
และ tf.keras.Model
แล้ว คุณสามารถย้ายข้อมูลการบันทึกโมเดลและการโหลดโค้ดได้ สมุดบันทึกนี้แสดงตัวอย่างวิธีบันทึกและโหลดในรูปแบบ SavedModel ใน TensorFlow 1 และ TensorFlow 2 ต่อไปนี้คือภาพรวมโดยย่อของการเปลี่ยนแปลง API ที่เกี่ยวข้องสำหรับการย้ายจาก TensorFlow 1 เป็น TensorFlow 2:
TensorFlow 1 | การโยกย้ายไปยัง TensorFlow 2 | |
---|---|---|
ประหยัด | tf.compat.v1.saved_model.Builder tf.compat.v1.saved_model.simple_save | tf.saved_model.save Keras: tf.keras.models.save_model |
กำลังโหลด | tf.compat.v1.saved_model.load | tf.saved_model.load Keras: tf.keras.models.load_model |
ลายเซ็น : ชุดอินพุต และเทนเซอร์เอาต์พุตที่ สามารถใช้เพื่อเรียกใช้ | สร้างโดยใช้ *.signature_def utils(เช่น tf.compat.v1.saved_model.predict_signature_def ) | เขียน tf.function และส่งออกโดยใช้อาร์กิวเมนต์ signatures ใน tf.saved_model.save |
การจำแนกประเภท และการถดถอย : ลายเซ็นประเภทพิเศษ | สร้างด้วยtf.compat.v1.saved_model.classification_signature_def ,tf.compat.v1.saved_model.regression_signature_def ,และการส่งออกเครื่องมือประมาณการบางอย่าง | ลายเซ็นทั้งสองประเภทนี้ถูกลบออกจาก TensorFlow 2 แล้ว หากไลบรารีที่ให้บริการต้องการชื่อเมธอดเหล่านี้ tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater |
สำหรับคำอธิบายเชิงลึกเพิ่มเติมของการทำแผนที่ โปรดดูส่วน การเปลี่ยนแปลงจาก TensorFlow 1 เป็น TensorFlow 2 ด้านล่าง
ติดตั้ง
ตัวอย่างด้านล่างแสดงวิธีการส่งออกและโหลดโมเดลจำลอง TensorFlow เดียวกัน (กำหนดเป็น add_two
ด้านล่าง) ไปยังรูปแบบ SavedModel โดยใช้ TensorFlow 1 และ TensorFlow 2 API เริ่มต้นด้วยการตั้งค่าฟังก์ชันการนำเข้าและยูทิลิตี้:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import shutil
def remove_dir(path):
try:
shutil.rmtree(path)
except:
pass
def add_two(input):
return input + 2
TensorFlow 1: บันทึกและส่งออก SavedModel
ใน TensorFlow 1 คุณใช้ tf.compat.v1.saved_model.Builder
, tf.compat.v1.saved_model.simple_save
และ tf.estimator.Estimator.export_saved_model
APIs เพื่อสร้าง บันทึก และส่งออกกราฟและเซสชัน TensorFlow:
1. บันทึกกราฟเป็น SavedModel ด้วย SavedModelBuilder
remove_dir("saved-model-builder")
with tf.Graph().as_default() as g:
with tf1.Session() as sess:
input = tf1.placeholder(tf.float32, shape=[])
output = add_two(input)
print("add two output: ", sess.run(output, {input: 3.}))
# Save with SavedModelBuilder
builder = tf1.saved_model.Builder('saved-model-builder')
sig_def = tf1.saved_model.predict_signature_def(
inputs={'input': input},
outputs={'output': output})
builder.add_meta_graph_and_variables(
sess, tags=["serve"], signature_def_map={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
})
builder.save()
add two output: 5.0 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:208: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. INFO:tensorflow:No assets to save. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: saved-model-builder/saved_model.pb
!saved_model_cli run --dir simple-save --tag_set serve \
--signature_def serving_default --input_exprs input=10
Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/bin/saved_model_cli", line 8, in <module> sys.exit(main()) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py", line 1211, in main args.func(args) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py", line 769, in run init_tpu=args.init_tpu, tf_debug=args.tf_debug) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py", line 417, in run_saved_model_with_feed_dict tag_set) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_utils.py", line 117, in get_meta_graph_def saved_model = read_saved_model(saved_model_dir) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_utils.py", line 55, in read_saved_model raise IOError("SavedModel file does not exist at: %s" % saved_model_dir) OSError: SavedModel file does not exist at: simple-save
2. สร้าง SavedModel สำหรับให้บริการ
remove_dir("simple-save")
with tf.Graph().as_default() as g:
with tf1.Session() as sess:
input = tf1.placeholder(tf.float32, shape=[])
output = add_two(input)
print("add_two output: ", sess.run(output, {input: 3.}))
tf1.saved_model.simple_save(
sess, 'simple-save',
inputs={'input': input},
outputs={'output': output})
add_two output: 5.0 WARNING:tensorflow:From /tmp/ipykernel_26511/250978412.py:12: simple_save (from tensorflow.python.saved_model.simple_save) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.simple_save. INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: simple-save/saved_model.pb
!saved_model_cli run --dir simple-save --tag_set serve \
--signature_def serving_default --input_exprs input=10
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored. Result for output key output: 12.0
3. ส่งออกกราฟการอนุมานของ Estimator เป็น SavedModel
ในคำจำกัดความของ Estimator model_fn
(กำหนดไว้ด้านล่าง) คุณสามารถกำหนดลายเซ็นในโมเดลของคุณโดยส่งคืน export_outputs
ใน tf.estimator.EstimatorSpec
เอาต์พุตมีหลายประเภท:
-
tf.estimator.export.ClassificationOutput
-
tf.estimator.export.RegressionOutput
-
tf.estimator.export.PredictOutput
สิ่งเหล่านี้จะสร้างประเภทการจำแนกประเภท การถดถอย และการทำนายตามลำดับ
เมื่อตัวประมาณถูกส่งออกด้วย tf.estimator.Estimator.export_saved_model
ลายเซ็นเหล่านี้จะถูกบันทึกพร้อมกับโมเดล
def model_fn(features, labels, mode):
output = add_two(features['input'])
step = tf1.train.get_global_step()
return tf.estimator.EstimatorSpec(
mode,
predictions=output,
train_op=step.assign_add(1),
loss=tf.constant(0.),
export_outputs={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: \
tf.estimator.export.PredictOutput({'output': output})})
est = tf.estimator.Estimator(model_fn, 'estimator-checkpoints')
# Train for one step to create a checkpoint.
def train_fn():
return tf.data.Dataset.from_tensors({'input': 3.})
est.train(train_fn, steps=1)
# This utility function `build_raw_serving_input_receiver_fn` takes in raw
# tensor features and builds an "input serving receiver function", which
# creates placeholder inputs to the model.
serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
{'input': tf.constant(3.)}) # Pass in a dummy input batch.
estimator_path = est.export_saved_model('exported-estimator', serving_input_fn)
# Estimator's export_saved_model creates a time stamped directory. Move this
# to a set path so it can be inspected with `saved_model_cli` in the cell below.
!rm -rf estimator-model
import shutil
shutil.move(estimator_path, 'estimator-model')
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': 'estimator-checkpoints', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:401: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into estimator-checkpoints/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.0, step = 1 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into estimator-checkpoints/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:Loss for final step: 0.0. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Signatures INCLUDED in export for Classify: None INFO:tensorflow:Signatures INCLUDED in export for Regress: None INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default'] INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Restoring parameters from estimator-checkpoints/model.ckpt-1 INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: exported-estimator/temp-1636162129/saved_model.pb 'estimator-model'
!saved_model_cli run --dir estimator-model --tag_set serve \
--signature_def serving_default --input_exprs input=[10]
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Restoring parameters from estimator-model/variables/variables Result for output key output: [12.]
TensorFlow 2: บันทึกและส่งออก SavedModel
บันทึกและส่งออก SavedModel ที่กำหนดด้วย tf.Module
ในการส่งออกโมเดลของคุณใน TensorFlow 2 คุณต้องกำหนด tf.Module
หรือ tf.keras.Model
เพื่อเก็บตัวแปรและฟังก์ชันทั้งหมดของโมเดลของคุณ จากนั้น คุณสามารถเรียก tf.saved_model.save
เพื่อสร้าง SavedModel อ้างถึง การ บันทึกโมเดลแบบกำหนดเอง ในคู่มือ การใช้รูปแบบ SavedModel เพื่อเรียนรู้เพิ่มเติม
class MyModel(tf.Module):
@tf.function
def __call__(self, input):
return add_two(input)
model = MyModel()
@tf.function
def serving_default(input):
return {'output': model(input)}
signature_function = serving_default.get_concrete_function(
tf.TensorSpec(shape=[], dtype=tf.float32))
tf.saved_model.save(
model, 'tf2-save', signatures={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_function})
INFO:tensorflow:Assets written to: tf2-save/assets 2021-11-06 01:28:53.105391: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
!saved_model_cli run --dir tf2-save --tag_set serve \
--signature_def serving_default --input_exprs input=10
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Restoring parameters from tf2-save/variables/variables Result for output key output: 12.0
บันทึกและส่งออก SavedModel ที่กำหนดด้วย Keras
Keras API สำหรับการบันทึกและส่งออก— Mode.save
หรือ tf.keras.models.save_model
— สามารถส่งออก SavedModel จาก tf.keras.Model
ตรวจสอบ บันทึกและโหลดโมเดล Keras สำหรับรายละเอียดเพิ่มเติม
inp = tf.keras.Input(3)
out = add_two(inp)
model = tf.keras.Model(inputs=inp, outputs=out)
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
def serving_default(input):
return {'output': model(input)}
model.save('keras-model', save_format='tf', signatures={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: serving_default})
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. WARNING:tensorflow:Model was constructed with shape (None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 3), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'"), but it was called on an input with incompatible shape (). INFO:tensorflow:Assets written to: keras-model/assets
!saved_model_cli run --dir keras-model --tag_set serve \
--signature_def serving_default --input_exprs input=10
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Restoring parameters from keras-model/variables/variables Result for output key output: 12.0
กำลังโหลด SavedModel
SavedModel ที่บันทึกด้วย API ใดๆ ข้างต้นสามารถโหลดได้โดยใช้ TensorFlow 1 หรือ TensorFlow API
TensorFlow 1 SavedModel โดยทั่วไปสามารถใช้สำหรับการอนุมานเมื่อโหลดเข้าสู่ TensorFlow 2 แต่การฝึกอบรม (การสร้างการไล่ระดับสี) เป็นไปได้ก็ต่อเมื่อ SavedModel มี ตัวแปรทรัพยากร คุณสามารถตรวจสอบ dtype ของตัวแปรได้—หากตัวแปร dtype มี "_ref" แสดงว่าเป็นตัวแปรอ้างอิง
TensorFlow 2 SavedModel สามารถโหลดและเรียกใช้จาก TensorFlow 1 ได้ ตราบใดที่ SavedModel ถูกบันทึกด้วยลายเซ็น
ส่วนด้านล่างประกอบด้วยตัวอย่างโค้ดที่แสดงวิธีการโหลด SavedModels ที่บันทึกไว้ในส่วนก่อนหน้า และเรียกลายเซ็นที่ส่งออก
TensorFlow 1: โหลด SavedModel ด้วย tf.saved_model.load
ใน TensorFlow 1 คุณสามารถนำเข้า SavedModel ลงในกราฟและเซสชันปัจจุบันได้โดยตรงโดยใช้ tf.saved_model.load
คุณสามารถเรียก Session.run
กับชื่ออินพุตและเอาต์พุตเทนเซอร์:
def load_tf1(path, input):
print('Loading from', path)
with tf.Graph().as_default() as g:
with tf1.Session() as sess:
meta_graph = tf1.saved_model.load(sess, ["serve"], path)
sig_def = meta_graph.signature_def[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
input_name = sig_def.inputs['input'].name
output_name = sig_def.outputs['output'].name
print(' Output with input', input, ': ',
sess.run(output_name, feed_dict={input_name: input}))
load_tf1('saved-model-builder', 5.)
load_tf1('simple-save', 5.)
load_tf1('estimator-model', [5.]) # Estimator's input must be batched.
load_tf1('tf2-save', 5.)
load_tf1('keras-model', 5.)
Loading from saved-model-builder WARNING:tensorflow:From /tmp/ipykernel_26511/1548963983.py:5: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored. Output with input 5.0 : 7.0 Loading from simple-save INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored. Output with input 5.0 : 7.0 Loading from estimator-model INFO:tensorflow:Restoring parameters from estimator-model/variables/variables Output with input [5.0] : [7.] Loading from tf2-save INFO:tensorflow:Restoring parameters from tf2-save/variables/variables Output with input 5.0 : 7.0 Loading from keras-model INFO:tensorflow:Restoring parameters from keras-model/variables/variables Output with input 5.0 : 7.0ตัวยึดตำแหน่ง23
TensorFlow 2: โหลดโมเดลที่บันทึกด้วย tf.saved_model
ใน TensorFlow 2 ออบเจ็กต์จะถูกโหลดลงในอ็อบเจ็กต์ Python ที่เก็บตัวแปรและฟังก์ชัน ใช้งานได้กับรุ่นที่บันทึกไว้จาก TensorFlow 1
ตรวจสอบเอกสาร tf.saved_model.load
API และ การโหลดและการใช้โมเดลแบบกำหนดเอง จากคู่มือ รูปแบบการใช้ SavedModel เพื่อดูรายละเอียด
def load_tf2(path, input):
print('Loading from', path)
loaded = tf.saved_model.load(path)
out = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY](
tf.constant(input))['output']
print(' Output with input', input, ': ', out)
load_tf2('saved-model-builder', 5.)
load_tf2('simple-save', 5.)
load_tf2('estimator-model', [5.]) # Estimator's input must be batched.
load_tf2('tf2-save', 5.)
load_tf2('keras-model', 5.)
Loading from saved-model-builder INFO:tensorflow:Saver not created because there are no variables in the graph to restore Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32) Loading from simple-save INFO:tensorflow:Saver not created because there are no variables in the graph to restore Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32) Loading from estimator-model Output with input [5.0] : tf.Tensor([7.], shape=(1,), dtype=float32) Loading from tf2-save Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32) Loading from keras-model Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32)
โมเดลที่บันทึกด้วย TensorFlow 2 API ยังสามารถเข้าถึง tf.function
และตัวแปรที่แนบมากับโมเดล (แทนที่จะส่งออกเป็นลายเซ็น) ตัวอย่างเช่น:
loaded = tf.saved_model.load('tf2-save')
print('restored __call__:', loaded.__call__)
print('output with input 5.', loaded(5))
restored __call__: <tensorflow.python.saved_model.function_deserialization.RestoredFunction object at 0x7f30cc940990> output with input 5. tf.Tensor(7.0, shape=(), dtype=float32)
TensorFlow 2: โหลดโมเดลที่บันทึกด้วย Keras
Keras กำลังโหลด API— tf.keras.models.load_model
อนุญาตให้คุณรีโหลดโมเดลที่บันทึกไว้กลับเข้าไปในอ็อบเจกต์ Keras Model โปรดทราบว่าวิธีนี้อนุญาตให้คุณโหลด SavedModels ที่บันทึกด้วย Keras เท่านั้น ( Model.save
หรือ tf.keras.models.save_model
)
โมเดลที่บันทึกด้วย tf.saved_model.save
ควรโหลดด้วย tf.saved_model.load
คุณสามารถโหลดโมเดล Keras ที่บันทึกด้วย Model.save
โดยใช้ tf.saved_model.load
แต่คุณจะได้เฉพาะกราฟ TensorFlow โปรดดูเอกสาร tf.keras.models.load_model
API และ บันทึกและโหลดคู่มือ Keras models สำหรับรายละเอียด
loaded_model = tf.keras.models.load_model('keras-model')
loaded_model.predict_on_batch(tf.constant([1, 3, 4]))
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually. WARNING:tensorflow:Model was constructed with shape (None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 3), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'"), but it was called on an input with incompatible shape (3,). array([3., 5., 6.], dtype=float32)
GraphDef และ MetaGraphDef
ไม่มีวิธีที่ตรงไปตรงมาในการโหลด Raw GraphDef
หรือ MetaGraphDef
ไปยัง TF2 อย่างไรก็ตาม คุณสามารถแปลงรหัส TF1 ที่นำเข้ากราฟเป็น TF2 concrete_function
โดยใช้ v1.wrap_function
ขั้นแรก บันทึก MetaGraphDef:
# Save a simple multiplication computation:
with tf.Graph().as_default() as g:
x = tf1.placeholder(tf.float32, shape=[], name='x')
v = tf.Variable(3.0, name='v')
y = tf.multiply(x, v, name='y')
with tf1.Session() as sess:
sess.run(v.initializer)
print(sess.run(y, feed_dict={x: 5}))
s = tf1.train.Saver()
s.export_meta_graph('multiply.pb', as_text=True)
s.save(sess, 'multiply_values.ckpt')
15.0
เมื่อใช้ TF1 API คุณสามารถใช้ tf1.train.import_meta_graph
เพื่อนำเข้ากราฟและคืนค่า:
with tf.Graph().as_default() as g:
meta = tf1.train.import_meta_graph('multiply.pb')
x = g.get_tensor_by_name('x:0')
y = g.get_tensor_by_name('y:0')
with tf1.Session() as sess:
meta.restore(sess, 'multiply_values.ckpt')
print(sess.run(y, feed_dict={x: 5}))
INFO:tensorflow:Restoring parameters from multiply_values.ckpt 15.0ตัวยึดตำแหน่ง33
ไม่มี TF2 API สำหรับการโหลดกราฟ แต่คุณยังสามารถนำเข้าไปยังฟังก์ชันที่เป็นรูปธรรมที่สามารถดำเนินการในโหมดกระตือรือร้นได้:
def import_multiply():
# Any graph-building code is allowed here.
tf1.train.import_meta_graph('multiply.pb')
# Creates a tf.function with all the imported elements in the function graph.
wrapped_import = tf1.wrap_function(import_multiply, [])
import_graph = wrapped_import.graph
x = import_graph.get_tensor_by_name('x:0')
y = import_graph.get_tensor_by_name('y:0')
# Restore the variable values.
tf1.train.Saver(wrapped_import.variables).restore(
sess=None, save_path='multiply_values.ckpt')
# Create a concrete function by pruning the wrap_function (similar to sess.run).
multiply_fn = wrapped_import.prune(feeds=x, fetches=y)
# Run this function
multiply_fn(tf.constant(5.)) # inputs to concrete functions must be Tensors.
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone. INFO:tensorflow:Restoring parameters from multiply_values.ckpt <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
การเปลี่ยนแปลงจาก TensorFlow 1 เป็น TensorFlow 2
ส่วนนี้แสดงรายการเงื่อนไขการบันทึกคีย์และการโหลดจาก TensorFlow 1, TensorFlow 2 ที่เทียบเท่า และสิ่งที่เปลี่ยนแปลงไป
รูปแบบที่บันทึกไว้
SavedModel เป็นรูปแบบที่เก็บโปรแกรม TensorFlow ที่สมบูรณ์พร้อมพารามิเตอร์และการคำนวณ ประกอบด้วยลายเซ็นที่ใช้โดยแพลตฟอร์มที่ให้บริการเพื่อเรียกใช้โมเดล
รูปแบบไฟล์เองไม่ได้เปลี่ยนแปลงอย่างมีนัยสำคัญ ดังนั้น SavedModels สามารถโหลดและให้บริการโดยใช้ TensorFlow 1 หรือ TensorFlow 2 APIs
ความแตกต่างระหว่าง TensorFlow 1 และ TensorFlow 2
กรณีการใช้งานการ เสิร์ฟ และการ อนุมาน ยังไม่ได้รับการอัปเดตใน TensorFlow 2 นอกเหนือจากการเปลี่ยนแปลง API— ได้มีการแนะนำการปรับปรุงในความสามารถใน การนำกลับมาใช้ใหม่ และ เขียนแบบจำลอง ที่โหลดจาก SavedModel
ใน TensorFlow 2 โปรแกรมจะแสดงด้วยอ็อบเจ็กต์ เช่น tf.Variable
, tf.Module
หรือ Keras ระดับสูงกว่า ( tf.keras.Model
) และเลเยอร์ ( tf.keras.layers
) ไม่มีตัวแปรส่วนกลางอีกต่อไปที่มีค่าที่เก็บไว้ในเซสชัน และขณะนี้กราฟมีอยู่ใน tf.function
ที่แตกต่างกัน ดังนั้น ในระหว่างการเอ็กซ์พอร์ตโมเดล SavedModel จะบันทึกแต่ละส่วนประกอบและกราฟฟังก์ชันแยกกัน
เมื่อคุณเขียนโปรแกรม TensorFlow ด้วย TensorFlow Python APIs คุณต้องสร้างวัตถุเพื่อจัดการตัวแปร ฟังก์ชัน และทรัพยากรอื่นๆ โดยทั่วไป สามารถทำได้โดยใช้ Keras API แต่คุณสามารถสร้างอ็อบเจ็กต์ได้ด้วยการสร้างหรือจัดคลาสย่อย tf.Module
โมเดล Keras ( tf.keras.Model
) และ tf.Module
จะติดตามตัวแปรและฟังก์ชันที่แนบมากับตัวแปรเหล่านี้โดยอัตโนมัติ SavedModel จะบันทึกการเชื่อมต่อเหล่านี้ระหว่างโมดูล ตัวแปร และฟังก์ชัน เพื่อให้สามารถกู้คืนได้เมื่อโหลด
ลายเซ็น
ลายเซ็นคือจุดสิ้นสุดของ SavedModel—ซึ่งบอกผู้ใช้ถึงวิธีเรียกใช้โมเดลและอินพุตที่จำเป็น
ใน TensorFlow 1 ลายเซ็นจะถูกสร้างขึ้นโดยการแสดงรายการเทนเซอร์อินพุตและเอาต์พุต ใน TensorFlow 2 ลายเซ็นจะถูกสร้างขึ้นโดยการส่งผ่าน ฟังก์ชันที่เป็นรูปธรรม (อ่านเพิ่มเติมเกี่ยวกับฟังก์ชัน TensorFlow ได้ใน บทนำเกี่ยวกับกราฟและคู่มือ tf.function ) กล่าวโดยย่อ ฟังก์ชันที่เป็นรูปธรรมจะถูกสร้างขึ้น จาก tf.function
:
# Option 1: Specify an input signature.
@tf.function(input_signature=[...])
def fn(...):
...
return outputs
tf.saved_model.save(model, path, signatures={
'name': fn
})
# Option 2: Call `get_concrete_function`
@tf.function
def fn(...):
...
return outputs
tf.saved_model.save(model, path, signatures={
'name': fn.get_concrete_function(...)
})
Session.run
ใน TensorFlow 1 คุณสามารถเรียก Session.run
ด้วยกราฟที่นำเข้า ตราบใดที่คุณทราบชื่อเทนเซอร์อยู่แล้ว วิธีนี้ทำให้คุณสามารถดึงค่าตัวแปรที่กู้คืน หรือเรียกใช้ส่วนต่างๆ ของโมเดลที่ไม่ได้ส่งออกในลายเซ็น
ใน TensorFlow 2 คุณสามารถเข้าถึงตัวแปรได้โดยตรง เช่น เมทริกซ์น้ำหนัก ( kernel
):
model = tf.Module()
model.dense_layer = tf.keras.layers.Dense(...)
tf.saved_model.save('my_saved_model')
loaded = tf.saved_model.load('my_saved_model')
loaded.dense_layer.kernel
หรือเรียกใช้ tf.function
ที่แนบมากับโมเดลวัตถุ เช่น loaded.__call__
ต่างจาก TF1 ตรงที่ไม่มีวิธีการแยกส่วนต่าง ๆ ของฟังก์ชันและเข้าถึงค่ากลาง คุณ ต้อง ส่งออกการทำงานที่จำเป็นทั้งหมดในวัตถุที่บันทึกไว้
บันทึกการย้ายข้อมูลการให้บริการ TensorFlow
SavedModel ถูกสร้างขึ้นเพื่อทำงานกับ TensorFlow Serving แพลตฟอร์มนี้นำเสนอคำขอการคาดการณ์ประเภทต่างๆ: จำแนก ถอยหลัง และคาดการณ์
TensorFlow 1 API ช่วยให้คุณสร้างลายเซ็นประเภทเหล่านี้ด้วย utils:
-
tf.compat.v1.saved_model.classification_signature_def
-
tf.compat.v1.saved_model.regression_signature_def
-
tf.compat.v1.saved_model.predict_signature_def
การจัดประเภท (การ classification_signature_def
_signature_def) และ การถดถอย ( regression_signature_def
) จำกัดอินพุตและเอาต์พุต ดังนั้นอินพุตต้องเป็น tf.Example
และเอาต์พุตต้องเป็น classes
scores
หรือ prediction
ในขณะเดียวกัน ลายเซ็นทำนาย ( predict_signature_def
) ไม่มีข้อจำกัด
SavedModels ที่ส่งออกด้วย TensorFlow 2 API เข้ากันได้กับ TensorFlow Serving แต่จะมีเฉพาะลายเซ็นการคาดการณ์เท่านั้น ลายเซ็นการจัดประเภทและการถดถอยถูกลบออก
หากคุณต้องการใช้ลายเซ็นการจัดประเภทและการถดถอย คุณสามารถแก้ไข SavedModel ที่ส่งออกได้โดยใช้ tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater
ขั้นตอนถัดไป
หากต้องการเรียนรู้เพิ่มเติมเกี่ยวกับ SavedModels ใน TensorFlow 2 โปรดดูคำแนะนำต่อไปนี้:
หากคุณกำลังใช้ TensorFlow Hub คุณอาจพบว่าคำแนะนำเหล่านี้มีประโยชน์: