TensorFlow.org에서 보기 | Google Colab에서 실행 | GitHub에서 소스 보기 | 노트북 다운로드 |
TensorFlow 1의 그래프 및 세션에서 TensorFlow 2 API(예: tf.function
, tf.Module
및 tf.keras.Model
)로 모델을 마이그레이션한 후에는 모델 저장 및 로드 코드를 마이그레이션할 수 있습니다. 이 노트북은 TensorFlow 1 및 TensorFlow 2에서 SavedModel 형식으로 저장하고 로드하는 방법에 대한 예를 제공합니다. 다음은 TensorFlow 1에서 TensorFlow 2로 마이그레이션하기 위한 관련 API 변경 사항에 대한 간략한 개요입니다.
텐서플로우 1 | TensorFlow 2로 마이그레이션 | |
---|---|---|
절약 | tf.compat.v1.saved_model.Builder tf.compat.v1.saved_model.simple_save | tf.saved_model.save 케라스: tf.keras.models.save_model |
로딩 중 | tf.compat.v1.saved_model.load | tf.saved_model.load 케라스: tf.keras.models.load_model |
서명 : 입력 집합 그리고 출력 텐서는 실행하는 데 사용할 수 있습니다. | *.signature_def 유틸리티를 사용하여 생성됨(예: tf.compat.v1.saved_model.predict_signature_def ) | tf.function 을 작성하고 signatures 인수를 사용하여 내보냅니다.tf.saved_model.save 에 있습니다. |
분류 및 회귀 : 특별한 유형의 서명 | 생성tf.compat.v1.saved_model.classification_signature_def ,tf.compat.v1.saved_model.regression_signature_def ,및 특정 Estimator 내보내기. | 이 두 가지 서명 유형은 TensorFlow 2에서 제거되었습니다. 서빙 라이브러리에 이러한 메서드 이름이 필요한 경우 tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater . |
매핑에 대한 더 자세한 설명은 아래 TensorFlow 1에서 TensorFlow 2로의 변경 사항 섹션을 참조하세요.
설정
아래 예제는 TensorFlow 1 및 TensorFlow 2 API를 사용하여 동일한 더미 TensorFlow 모델(아래 add_two
로 정의됨)을 SavedModel 형식으로 내보내고 로드하는 방법을 보여줍니다. 가져오기 및 유틸리티 기능을 설정하여 시작합니다.
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import shutil
def remove_dir(path):
try:
shutil.rmtree(path)
except:
pass
def add_two(input):
return input + 2
TensorFlow 1: 저장된 모델 저장 및 내보내기
TensorFlow 1에서는 tf.compat.v1.saved_model.Builder
, tf.compat.v1.saved_model.simple_save
및 tf.estimator.Estimator.export_saved_model
API를 사용하여 TensorFlow 그래프 및 세션을 빌드, 저장 및 내보내기합니다.
1. SavedModelBuilder를 사용하여 그래프를 저장된 모델로 저장
remove_dir("saved-model-builder")
with tf.Graph().as_default() as g:
with tf1.Session() as sess:
input = tf1.placeholder(tf.float32, shape=[])
output = add_two(input)
print("add two output: ", sess.run(output, {input: 3.}))
# Save with SavedModelBuilder
builder = tf1.saved_model.Builder('saved-model-builder')
sig_def = tf1.saved_model.predict_signature_def(
inputs={'input': input},
outputs={'output': output})
builder.add_meta_graph_and_variables(
sess, tags=["serve"], signature_def_map={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: sig_def
})
builder.save()
add two output: 5.0 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:208: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info. INFO:tensorflow:No assets to save. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: saved-model-builder/saved_model.pb
!saved_model_cli run --dir simple-save --tag_set serve \
--signature_def serving_default --input_exprs input=10
Traceback (most recent call last): File "/tmpfs/src/tf_docs_env/bin/saved_model_cli", line 8, in <module> sys.exit(main()) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py", line 1211, in main args.func(args) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py", line 769, in run init_tpu=args.init_tpu, tf_debug=args.tf_debug) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py", line 417, in run_saved_model_with_feed_dict tag_set) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_utils.py", line 117, in get_meta_graph_def saved_model = read_saved_model(saved_model_dir) File "/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_utils.py", line 55, in read_saved_model raise IOError("SavedModel file does not exist at: %s" % saved_model_dir) OSError: SavedModel file does not exist at: simple-save
2. 제공을 위한 저장된 모델 빌드
remove_dir("simple-save")
with tf.Graph().as_default() as g:
with tf1.Session() as sess:
input = tf1.placeholder(tf.float32, shape=[])
output = add_two(input)
print("add_two output: ", sess.run(output, {input: 3.}))
tf1.saved_model.simple_save(
sess, 'simple-save',
inputs={'input': input},
outputs={'output': output})
add_two output: 5.0 WARNING:tensorflow:From /tmp/ipykernel_26511/250978412.py:12: simple_save (from tensorflow.python.saved_model.simple_save) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.simple_save. INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: simple-save/saved_model.pb
!saved_model_cli run --dir simple-save --tag_set serve \
--signature_def serving_default --input_exprs input=10
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored. Result for output key output: 12.0
3. Estimator 추론 그래프를 저장된 모델로 내보내기
Estimator model_fn
(아래 정의)의 정의에서 tf.estimator.EstimatorSpec
에서 export_outputs
를 반환하여 모델에 서명을 정의할 수 있습니다. 다양한 유형의 출력이 있습니다.
-
tf.estimator.export.ClassificationOutput
-
tf.estimator.export.RegressionOutput
-
tf.estimator.export.PredictOutput
이들은 각각 분류, 회귀 및 예측 서명 유형을 생성합니다.
추정기를 tf.estimator.Estimator.export_saved_model
과 함께 내보내면 이러한 서명이 모델과 함께 저장됩니다.
def model_fn(features, labels, mode):
output = add_two(features['input'])
step = tf1.train.get_global_step()
return tf.estimator.EstimatorSpec(
mode,
predictions=output,
train_op=step.assign_add(1),
loss=tf.constant(0.),
export_outputs={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: \
tf.estimator.export.PredictOutput({'output': output})})
est = tf.estimator.Estimator(model_fn, 'estimator-checkpoints')
# Train for one step to create a checkpoint.
def train_fn():
return tf.data.Dataset.from_tensors({'input': 3.})
est.train(train_fn, steps=1)
# This utility function `build_raw_serving_input_receiver_fn` takes in raw
# tensor features and builds an "input serving receiver function", which
# creates placeholder inputs to the model.
serving_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
{'input': tf.constant(3.)}) # Pass in a dummy input batch.
estimator_path = est.export_saved_model('exported-estimator', serving_input_fn)
# Estimator's export_saved_model creates a time stamped directory. Move this
# to a set path so it can be inspected with `saved_model_cli` in the cell below.
!rm -rf estimator-model
import shutil
shutil.move(estimator_path, 'estimator-model')
INFO:tensorflow:Using default config. INFO:tensorflow:Using config: {'_model_dir': 'estimator-checkpoints', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:401: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into estimator-checkpoints/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.0, step = 1 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 1... INFO:tensorflow:Saving checkpoints for 1 into estimator-checkpoints/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 1... INFO:tensorflow:Loss for final step: 0.0. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Signatures INCLUDED in export for Classify: None INFO:tensorflow:Signatures INCLUDED in export for Regress: None INFO:tensorflow:Signatures INCLUDED in export for Predict: ['serving_default'] INFO:tensorflow:Signatures INCLUDED in export for Train: None INFO:tensorflow:Signatures INCLUDED in export for Eval: None INFO:tensorflow:Restoring parameters from estimator-checkpoints/model.ckpt-1 INFO:tensorflow:Assets added to graph. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: exported-estimator/temp-1636162129/saved_model.pb 'estimator-model'
!saved_model_cli run --dir estimator-model --tag_set serve \
--signature_def serving_default --input_exprs input=[10]
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Restoring parameters from estimator-model/variables/variables Result for output key output: [12.]
TensorFlow 2: 저장된 모델 저장 및 내보내기
tf.Module로 정의된 SavedModel 저장 및 내보내기
TensorFlow 2에서 모델을 내보내려면 모델의 모든 변수와 함수를 보유할 tf.Module
또는 tf.keras.Model
을 정의해야 합니다. 그런 다음 tf.saved_model.save
를 호출하여 저장된 모델을 생성할 수 있습니다. 자세한 내용은 저장된 모델 형식 사용 가이드에서 사용자 지정 모델 저장 을 참조하세요.
class MyModel(tf.Module):
@tf.function
def __call__(self, input):
return add_two(input)
model = MyModel()
@tf.function
def serving_default(input):
return {'output': model(input)}
signature_function = serving_default.get_concrete_function(
tf.TensorSpec(shape=[], dtype=tf.float32))
tf.saved_model.save(
model, 'tf2-save', signatures={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_function})
INFO:tensorflow:Assets written to: tf2-save/assets 2021-11-06 01:28:53.105391: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
!saved_model_cli run --dir tf2-save --tag_set serve \
--signature_def serving_default --input_exprs input=10
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Restoring parameters from tf2-save/variables/variables Result for output key output: 12.0
Keras로 정의된 SavedModel 저장 및 내보내기
Mode.save
또는 tf.keras.models.save_model
저장 및 내보내기용 Keras API는 tf.keras.Model
에서 저장된 모델을 내보낼 수 있습니다. 자세한 내용은 Keras 모델 저장 및 로드 를 확인하세요.
inp = tf.keras.Input(3)
out = add_two(inp)
model = tf.keras.Model(inputs=inp, outputs=out)
@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.float32)])
def serving_default(input):
return {'output': model(input)}
model.save('keras-model', save_format='tf', signatures={
tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: serving_default})
WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. WARNING:tensorflow:Model was constructed with shape (None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 3), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'"), but it was called on an input with incompatible shape (). INFO:tensorflow:Assets written to: keras-model/assets
!saved_model_cli run --dir keras-model --tag_set serve \
--signature_def serving_default --input_exprs input=10
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/tools/saved_model_cli.py:453: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Restoring parameters from keras-model/variables/variables Result for output key output: 12.0
저장된 모델 로드
위 API 중 하나로 저장된 SavedModel은 TensorFlow 1 또는 TensorFlow API를 사용하여 로드할 수 있습니다.
TensorFlow 1 SavedModel은 일반적으로 TensorFlow 2에 로드될 때 추론에 사용할 수 있지만 교육(그라디언트 생성)은 SavedModel에 리소스 변수 가 포함된 경우에만 가능합니다. 변수의 dtype을 확인할 수 있습니다. 변수 dtype에 "_ref"가 포함되어 있으면 참조 변수입니다.
TensorFlow 2 SavedModel은 SavedModel이 서명과 함께 저장되어 있는 한 TensorFlow 1에서 로드 및 실행할 수 있습니다.
아래 섹션에는 이전 섹션에서 저장한 저장된 모델을 로드하고 내보낸 서명을 호출하는 방법을 보여주는 코드 샘플이 포함되어 있습니다.
TensorFlow 1: tf.saved_model.load로 저장된 모델 로드
TensorFlow 1에서는 tf.saved_model.load
를 사용하여 저장된 모델을 현재 그래프와 세션으로 직접 가져올 수 있습니다. 텐서 입력 및 출력 이름에서 Session.run
을 호출할 수 있습니다.
def load_tf1(path, input):
print('Loading from', path)
with tf.Graph().as_default() as g:
with tf1.Session() as sess:
meta_graph = tf1.saved_model.load(sess, ["serve"], path)
sig_def = meta_graph.signature_def[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
input_name = sig_def.inputs['input'].name
output_name = sig_def.outputs['output'].name
print(' Output with input', input, ': ',
sess.run(output_name, feed_dict={input_name: input}))
load_tf1('saved-model-builder', 5.)
load_tf1('simple-save', 5.)
load_tf1('estimator-model', [5.]) # Estimator's input must be batched.
load_tf1('tf2-save', 5.)
load_tf1('keras-model', 5.)
Loading from saved-model-builder WARNING:tensorflow:From /tmp/ipykernel_26511/1548963983.py:5: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version. Instructions for updating: This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0. INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored. Output with input 5.0 : 7.0 Loading from simple-save INFO:tensorflow:Saver not created because there are no variables in the graph to restore INFO:tensorflow:The specified SavedModel has no variables; no checkpoints were restored. Output with input 5.0 : 7.0 Loading from estimator-model INFO:tensorflow:Restoring parameters from estimator-model/variables/variables Output with input [5.0] : [7.] Loading from tf2-save INFO:tensorflow:Restoring parameters from tf2-save/variables/variables Output with input 5.0 : 7.0 Loading from keras-model INFO:tensorflow:Restoring parameters from keras-model/variables/variables Output with input 5.0 : 7.0
TensorFlow 2: tf.saved_model로 저장된 모델 로드
TensorFlow 2에서 객체는 변수와 함수를 저장하는 Python 객체에 로드됩니다. 이것은 TensorFlow 1에서 저장된 모델과 호환됩니다.
자세한 내용은 tf.saved_model.load
API 문서와 저장된 모델 형식 사용 가이드 의 사용자 지정 모델 로드 및 사용 섹션을 확인하세요.
def load_tf2(path, input):
print('Loading from', path)
loaded = tf.saved_model.load(path)
out = loaded.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY](
tf.constant(input))['output']
print(' Output with input', input, ': ', out)
load_tf2('saved-model-builder', 5.)
load_tf2('simple-save', 5.)
load_tf2('estimator-model', [5.]) # Estimator's input must be batched.
load_tf2('tf2-save', 5.)
load_tf2('keras-model', 5.)
Loading from saved-model-builder INFO:tensorflow:Saver not created because there are no variables in the graph to restore Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32) Loading from simple-save INFO:tensorflow:Saver not created because there are no variables in the graph to restore Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32) Loading from estimator-model Output with input [5.0] : tf.Tensor([7.], shape=(1,), dtype=float32) Loading from tf2-save Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32) Loading from keras-model Output with input 5.0 : tf.Tensor(7.0, shape=(), dtype=float32)
TensorFlow 2 API로 저장된 모델은 tf.function
및 모델에 첨부된 변수에 액세스할 수도 있습니다(시그니처로 내보낸 모델 대신). 예를 들어:
loaded = tf.saved_model.load('tf2-save')
print('restored __call__:', loaded.__call__)
print('output with input 5.', loaded(5))
restored __call__: <tensorflow.python.saved_model.function_deserialization.RestoredFunction object at 0x7f30cc940990> output with input 5. tf.Tensor(7.0, shape=(), dtype=float32)
TensorFlow 2: Keras로 저장된 모델 로드
Keras 로딩 API tf.keras.models.load_model
를 사용하면 저장된 모델을 다시 Keras 모델 객체로 다시 로드할 수 있습니다. 이것은 Keras( Model.save
또는 tf.keras.models.save_model
)로 저장된 저장된 모델만 로드할 수 있다는 점에 유의하십시오.
tf.saved_model.load
로 저장된 모델은 tf.saved_model.save
로 로드해야 합니다. tf.saved_model.load를 사용하여 tf.saved_model.load
로 저장된 Model.save
모델을 로드할 수 있지만 TensorFlow 그래프만 얻을 수 있습니다. 자세한 내용은 tf.keras.models.load_model
API 문서 및 Keras 모델 저장 및 로드 가이드를 참조하세요.
loaded_model = tf.keras.models.load_model('keras-model')
loaded_model.predict_on_batch(tf.constant([1, 3, 4]))
WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually. WARNING:tensorflow:Model was constructed with shape (None, 3) for input KerasTensor(type_spec=TensorSpec(shape=(None, 3), dtype=tf.float32, name='input_1'), name='input_1', description="created by layer 'input_1'"), but it was called on an input with incompatible shape (3,). array([3., 5., 6.], dtype=float32)
GraphDef 및 MetaGraphDef
원시 GraphDef
또는 MetaGraphDef
를 TF2로 로드하는 간단한 방법은 없습니다. 그러나 v1.wrap_function
을 사용하여 그래프를 TF2 concrete_function
으로 가져오는 TF1 코드를 변환할 수 있습니다.
먼저 MetaGraphDef를 저장합니다.
# Save a simple multiplication computation:
with tf.Graph().as_default() as g:
x = tf1.placeholder(tf.float32, shape=[], name='x')
v = tf.Variable(3.0, name='v')
y = tf.multiply(x, v, name='y')
with tf1.Session() as sess:
sess.run(v.initializer)
print(sess.run(y, feed_dict={x: 5}))
s = tf1.train.Saver()
s.export_meta_graph('multiply.pb', as_text=True)
s.save(sess, 'multiply_values.ckpt')
15.0
TF1 API를 사용하면 tf1.train.import_meta_graph
를 사용하여 그래프를 가져오고 값을 복원할 수 있습니다.
with tf.Graph().as_default() as g:
meta = tf1.train.import_meta_graph('multiply.pb')
x = g.get_tensor_by_name('x:0')
y = g.get_tensor_by_name('y:0')
with tf1.Session() as sess:
meta.restore(sess, 'multiply_values.ckpt')
print(sess.run(y, feed_dict={x: 5}))
INFO:tensorflow:Restoring parameters from multiply_values.ckpt 15.0
그래프를 로드하기 위한 TF2 API는 없지만 열망 모드에서 실행할 수 있는 구체적인 함수로 가져올 수 있습니다.
def import_multiply():
# Any graph-building code is allowed here.
tf1.train.import_meta_graph('multiply.pb')
# Creates a tf.function with all the imported elements in the function graph.
wrapped_import = tf1.wrap_function(import_multiply, [])
import_graph = wrapped_import.graph
x = import_graph.get_tensor_by_name('x:0')
y = import_graph.get_tensor_by_name('y:0')
# Restore the variable values.
tf1.train.Saver(wrapped_import.variables).restore(
sess=None, save_path='multiply_values.ckpt')
# Create a concrete function by pruning the wrap_function (similar to sess.run).
multiply_fn = wrapped_import.prune(feeds=x, fetches=y)
# Run this function
multiply_fn(tf.constant(5.)) # inputs to concrete functions must be Tensors.
WARNING:tensorflow:Saver is deprecated, please switch to tf.train.Checkpoint or tf.keras.Model.save_weights for training checkpoints. When executing eagerly variables do not necessarily have unique names, and so the variable.name-based lookups Saver performs are error-prone. INFO:tensorflow:Restoring parameters from multiply_values.ckpt <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
TensorFlow 1에서 TensorFlow 2로의 변경 사항
이 섹션에서는 TensorFlow 1의 주요 저장 및 로드 용어, 이에 상응하는 TensorFlow 2 및 변경된 사항을 나열합니다.
저장된 모델
SavedModel 은 매개변수 및 계산이 포함된 완전한 TensorFlow 프로그램을 저장하는 형식입니다. 여기에는 모델을 실행하기 위해 플랫폼을 제공하는 데 사용되는 서명이 포함됩니다.
파일 형식 자체는 크게 변경되지 않았으므로 TensorFlow 1 또는 TensorFlow 2 API를 사용하여 SavedModel을 로드하고 제공할 수 있습니다.
TensorFlow 1과 TensorFlow 2의 차이점
제공 및 추론 사용 사례는 API 변경을 제외하고 TensorFlow 2에서 업데이트되지 않았습니다. SavedModel에서 로드된 모델 을 재사용 하고 구성하는 기능이 개선되었습니다.
TensorFlow 2에서 프로그램은 tf.Variable
, tf.Module
또는 상위 수준 Keras 모델( tf.keras.Model
) 및 레이어( tf.keras.layers
)와 같은 객체로 표현됩니다. 세션에 값이 저장된 전역 변수는 더 이상 없으며 그래프는 이제 다른 tf.function
에 존재합니다. 따라서 모델을 내보내는 동안 SavedModel은 각 구성 요소와 함수 그래프를 별도로 저장합니다.
TensorFlow Python API로 TensorFlow 프로그램을 작성할 때 변수, 함수 및 기타 리소스를 관리하기 위한 객체를 빌드해야 합니다. 일반적으로 이는 Keras API를 사용하여 수행되지만 tf.Module
을 생성하거나 서브클래싱하여 객체를 빌드할 수도 있습니다.
Keras 모델( tf.keras.Model
)과 tf.Module
자동으로 연결된 변수와 함수를 추적합니다. SavedModel은 모듈, 변수 및 함수 간의 이러한 연결을 저장하여 로드할 때 복원할 수 있습니다.
서명
서명은 저장된 모델의 끝점입니다. 서명은 사용자에게 모델을 실행하는 방법과 필요한 입력을 알려줍니다.
TensorFlow 1에서 서명은 입력 및 출력 텐서를 나열하여 생성됩니다. TensorFlow 2에서 서명은 구체적인 함수 를 전달하여 생성됩니다. ( 그래프 소개 및 tf.function 가이드에서 TensorFlow 함수에 대해 자세히 읽어보세요.) 간단히 말해서, 구체적인 함수는 tf.function
에서 생성됩니다.
# Option 1: Specify an input signature.
@tf.function(input_signature=[...])
def fn(...):
...
return outputs
tf.saved_model.save(model, path, signatures={
'name': fn
})
# Option 2: Call `get_concrete_function`
@tf.function
def fn(...):
...
return outputs
tf.saved_model.save(model, path, signatures={
'name': fn.get_concrete_function(...)
})
Session.run
TensorFlow 1에서는 이미 텐서 이름을 알고 있는 한 가져온 그래프로 Session.run
을 호출할 수 있습니다. 이를 통해 복원된 변수 값을 검색하거나 서명에서 내보내지 않은 모델 부분을 실행할 수 있습니다.
TensorFlow 2에서는 가중치 행렬( kernel
)과 같은 변수에 직접 액세스할 수 있습니다.
model = tf.Module()
model.dense_layer = tf.keras.layers.Dense(...)
tf.saved_model.save('my_saved_model')
loaded = tf.saved_model.load('my_saved_model')
loaded.dense_layer.kernel
또는 모델 객체에 첨부된 tf.function
s를 호출합니다(예: loaded.__call__
.
TF1과 달리 함수의 일부를 추출하고 중간 값에 액세스할 수 있는 방법이 없습니다. 저장된 개체에서 필요한 모든 기능을 내 보내야 합니다.
TensorFlow Serving 마이그레이션 참고사항
SavedModel은 원래 TensorFlow Serving 과 함께 작동하도록 만들어졌습니다. 이 플랫폼은 분류, 회귀 및 예측과 같은 다양한 유형의 예측 요청을 제공합니다.
TensorFlow 1 API를 사용하면 유틸리티를 사용하여 다음 유형의 서명을 만들 수 있습니다.
-
tf.compat.v1.saved_model.classification_signature_def
-
tf.compat.v1.saved_model.regression_signature_def
-
tf.compat.v1.saved_model.predict_signature_def
분류 ( classification_signature_def
) 및 회귀 ( regression_signature_def
)는 입력과 출력을 제한하므로 입력은 tf.Example
이어야 하고 출력은 classes
, scores
또는 prediction
이어야 합니다. 한편, 예측 서명 ( predict_signature_def
)에는 제한이 없습니다.
TensorFlow 2 API로 내보낸 저장된 모델은 TensorFlow Serving과 호환되지만 예측 서명만 포함합니다. 분류 및 회귀 서명이 제거되었습니다.
분류 및 회귀 서명을 사용해야 하는 경우 tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater를 사용하여 내보낸 tf.compat.v1.saved_model.signature_def_utils.MethodNameUpdater
을 수정할 수 있습니다.
다음 단계
TensorFlow 2의 저장된 모델에 대해 자세히 알아보려면 다음 가이드를 확인하세요.
TensorFlow Hub를 사용하는 경우 다음 가이드가 유용할 수 있습니다.