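Warning: Estimators are not recommended for new code. Estimators run v1.Session-style code, which is harder to write correctly and can behave unexpectedly, especially when combined with TF 2 code. Estimators are covered by the compatibility guarantees, but they will receive no fixes other than for security vulnerabilities. See the migration guide for details.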
Overview
TensorFlow Estimators are supported in TensorFlow and can be created from new and existing tf.keras models. This tutorial contains a complete, minimal example of that process.
Note: If you have a Keras model, you can use it directly with tf.distribute strategies without converting it to an Estimator. For that reason, model_to_estimator is no longer recommended.
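As a hedged illustration of the note above, the following sketch trains a small Keras model directly under tf.distribute.MirroredStrategy, with no Estimator conversion. The synthetic data, the layer sizes, and names such as dist_model are assumptions made for this example and are not part of the original tutorial.

```python
import numpy as np
import tensorflow as tf

# MirroredStrategy replicates the model across the GPUs of one machine
# and runs with a single replica on CPU-only hosts.
strategy = tf.distribute.MirroredStrategy()

# Model weights and optimizer state must be created inside the scope.
with strategy.scope():
    dist_model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(3),
    ])
    dist_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer='adam')

# Synthetic stand-in data (assumption): 64 rows, 4 features, 3 classes.
rng = np.random.default_rng(0)
x = rng.random((64, 4)).astype('float32')
y = rng.integers(0, 3, size=(64,))

# fit() is called outside the scope; Keras distributes the batches itself.
history = dist_model.fit(x, y, batch_size=32, epochs=2, verbose=0)
print(history.history['loss'])
```

The only change compared with single-device Keras code is that the model is built and compiled inside strategy.scope().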
Setup
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
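Create a simple Keras model
In Keras, you build models by assembling layers. A model is usually a graph of layers. The most common type is a plain stack of layers: the tf.keras.Sequential model.
To build a simple, fully connected network (a multi-layer perceptron):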
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(3)
])
Compile the model, then print a summary of its structure.
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer='adam')
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense (Dense)               (None, 16)                80
 dropout (Dropout)           (None, 16)                0
 dense_1 (Dense)             (None, 3)                 51
=================================================================
Total params: 131
Trainable params: 131
Non-trainable params: 0
_________________________________________________________________
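The parameter counts in the summary can be checked by hand. The short sketch below reproduces them, assuming the standard Dense layout of one weight per (input, unit) pair plus one bias per unit; the helper name dense_params is hypothetical.

```python
def dense_params(n_inputs, n_units):
    """Trainable parameters of a Dense layer: weights plus one bias per unit."""
    return n_inputs * n_units + n_units

hidden = dense_params(4, 16)   # Dense(16) applied to the 4 input features
dropout = 0                    # Dropout has no trainable parameters
output = dense_params(16, 3)   # Dense(3) applied to the 16 hidden units

print(hidden, dropout, output, hidden + dropout + output)
# prints: 80 0 51 131
```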
Create an input function
Use the Datasets API to scale to large datasets or to train on multiple devices.
Estimators need control over when and how their input pipeline is built. To allow this, they require an "input function", or input_fn. The Estimator calls this function with no arguments, and input_fn must return a tf.data.Dataset.
def input_fn():
    split = tfds.Split.TRAIN
    dataset = tfds.load('iris', split=split, as_supervised=True)
    dataset = dataset.map(lambda features, labels: ({'dense_input':features}, labels))
    dataset = dataset.batch(32).repeat()
    return dataset
Check that input_fn works as expected.
for features_batch, labels_batch in input_fn().take(1):
    print(features_batch)
    print(labels_batch)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
{'dense_input': <tf.Tensor: shape=(32, 4), dtype=float32, numpy=
array([[5.1, 3.4, 1.5, 0.2],
       [7.7, 3. , 6.1, 2.3],
       [5.7, 2.8, 4.5, 1.3],
       [6.8, 3.2, 5.9, 2.3],
       [5.2, 3.4, 1.4, 0.2],
       [5.6, 2.9, 3.6, 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [5.5, 2.4, 3.7, 1. ],
       [4.6, 3.4, 1.4, 0.3],
       [7.7, 2.8, 6.7, 2. ],
       [7. , 3.2, 4.7, 1.4],
       [4.6, 3.2, 1.4, 0.2],
       [6.5, 3. , 5.2, 2. ],
       [5.5, 4.2, 1.4, 0.2],
       [5.4, 3.9, 1.3, 0.4],
       [5. , 3.5, 1.3, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [4.8, 3. , 1.4, 0.1],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [6.7, 3.3, 5.7, 2.1],
       [7.9, 3.8, 6.4, 2. ],
       [6.7, 3. , 5.2, 2.3],
       [5.8, 4. , 1.2, 0.2],
       [6.3, 2.5, 5. , 1.9],
       [5. , 3. , 1.6, 0.2],
       [6.9, 3.1, 5.1, 2.3],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.7, 4.1, 1. ],
       [5.2, 2.7, 3.9, 1.4],
       [6.7, 3. , 5. , 1.7],
       [5.7, 2.6, 3.5, 1. ]], dtype=float32)>}
tf.Tensor([0 2 1 2 0 1 1 1 0 2 1 0 2 0 0 0 0 0 2 2 2 2 2 0 2 0 2 1 1 1 1 1], shape=(32,), dtype=int64)
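Because the Estimator calls input_fn without arguments, any configuration has to be bound before the function is handed over. One possible way to do that, sketched here under assumptions, is a factory that returns a zero-argument closure. The helper name make_input_fn, its batch_size and training parameters, and the synthetic arrays are hypothetical; only the feature key dense_input is taken from the tutorial code.

```python
import numpy as np
import tensorflow as tf

def make_input_fn(features, labels, batch_size=32, training=True):
    """Return a zero-argument input_fn that builds a tf.data.Dataset."""
    def input_fn():
        dataset = tf.data.Dataset.from_tensor_slices(
            ({'dense_input': features}, labels))
        if training:
            # Shuffle and repeat indefinitely; the caller bounds training with `steps`.
            dataset = dataset.shuffle(len(labels)).repeat()
        return dataset.batch(batch_size)
    return input_fn

# Synthetic stand-in data (assumption): 150 rows, 4 features, 3 classes.
rng = np.random.default_rng(0)
features = rng.random((150, 4)).astype('float32')
labels = rng.integers(0, 3, size=(150,))

train_input_fn = make_input_fn(features, labels, batch_size=32)
features_batch, labels_batch = next(iter(train_input_fn()))
print(features_batch['dense_input'].shape, labels_batch.shape)
```

A closure keeps the returned function's signature empty, so the same factory can produce separate training and evaluation input functions from different arrays.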
Create an Estimator from the tf.keras model
A tf.keras.Model can be trained through the tf.estimator API by converting it to a tf.estimator.Estimator object with tf.keras.estimator.model_to_estimator.
import tempfile
model_dir = tempfile.mkdtemp()
keras_estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir=model_dir)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using the Keras model provided.
WARNING:absl:You are using `tf.keras.optimizers.experimental.Optimizer` in TF estimator, which only supports `tf.keras.optimizers.legacy.Optimizer`. Automatically converting your optimizer to `tf.keras.optimizers.legacy.Optimizer`.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/backend.py:451: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn(
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpb42gnr_2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
2022-12-14 22:29:02.013323: W tensorflow/c/c_api.cc:291] Operation '{name:'training/Adam/dense_1/bias/v/Assign' id:218 op device:{requested: '', assigned: ''} def:{ { {node training/Adam/dense_1/bias/v/Assign} } = AssignVariableOp[_has_manual_control_dependencies=true, dtype=DT_FLOAT, validate_shape=false](training/Adam/dense_1/bias/v, training/Adam/dense_1/bias/v/Initializer/zeros)} }' was changed by setting attribute after it was run by a session. This mutation will have no effect, and will trigger an error in the future. Either don't modify nodes after running them or create a new session.
Train and evaluate the Estimator.
keras_estimator.train(input_fn=input_fn, steps=500)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmpfs/tmp/tmpb42gnr_2/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmpfs/tmp/tmpb42gnr_2/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 4 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpb42gnr_2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 3.7705293, step = 0
INFO:tensorflow:global_step/sec: 47.0665
INFO:tensorflow:loss = 0.7411718, step = 100 (2.127 sec)
INFO:tensorflow:global_step/sec: 48.7542
INFO:tensorflow:loss = 0.5705185, step = 200 (2.051 sec)
INFO:tensorflow:global_step/sec: 47.5396
INFO:tensorflow:loss = 0.59224725, step = 300 (2.104 sec)
INFO:tensorflow:global_step/sec: 50.0255
INFO:tensorflow:loss = 0.44595116, step = 400 (1.998 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...
INFO:tensorflow:Saving checkpoints for 500 into /tmpfs/tmp/tmpb42gnr_2/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...
INFO:tensorflow:Loss for final step: 0.62517214.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/engine/training_v1.py:2333: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  updates = self.state_updates
INFO:tensorflow:Starting evaluation at 2022-12-14T22:29:14
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpb42gnr_2/model.ckpt-500
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.46966s
INFO:tensorflow:Finished evaluation at 2022-12-14-22:29:15
INFO:tensorflow:Saving dict for global step 500: global_step = 500, loss = 0.3777611
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmpfs/tmp/tmpb42gnr_2/model.ckpt-500
Eval result: {'loss': 0.3777611, 'global_step': 500}
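Since input_fn repeats the dataset indefinitely, the steps argument is what bounds training and evaluation. The arithmetic below is a rough sketch of how those step counts map onto passes over the data. It assumes the Iris train split holds 150 examples, a dataset fact that the tutorial text itself does not state.

```python
import math

num_examples = 150   # assumed size of the Iris 'train' split
batch_size = 32      # matches dataset.batch(32) in input_fn

# batch() runs before repeat(), so every pass ends with one smaller batch.
steps_per_epoch = math.ceil(num_examples / batch_size)

train_epochs = 500 / steps_per_epoch   # steps=500 in train()
eval_epochs = 10 / steps_per_epoch     # steps=10 in evaluate()

print(steps_per_epoch, train_epochs, eval_epochs)
# prints: 5 100.0 2.0
```

Note that evaluate() reuses the training input_fn here, so the reported loss is measured on the same data the model was trained on.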