TensorFlow.org에서 보기 | Google Colab에서 실행 | GitHub에서 소스 보기 | 노트북 다운로드 |
TensorFlow 1에서는 사용 tf.estimator.LoggingTensorHook
하면서, 텐서를 모니터링하고 기록하는 tf.estimator.StopAtStepHook
지정된 단계에서 정지 훈련을하는 데 도움이 때와 훈련 tf.estimator.Estimator
. 이 노트북은 사용자 정의 Keras 콜백 (사용 TensorFlow 2에서 그 등가물에 이러한 API에서 마이그레이션하는 방법을 보여줍니다 tf.keras.callbacks.Callback
포함) Model.fit
.
Keras 콜백 Model.fit
/ Model.evaluate
/ Model.predict
API에서 학습/평가/예측 중에 서로 다른 지점에서 호출되는 객체입니다. 콜백에 대한 자세한 내용은 tf.keras.callbacks.Callback
API 문서와 자체 콜백 작성 및 내장 메서드를 사용한 교육 및 평가 ( 콜백 사용 섹션) 가이드를 참조하세요. SessionRunHook
에서 TensorFlow 2의 Keras 콜백으로 마이그레이션하려면 지원 논리를 사용한 마이그레이션 교육 가이드를 확인하세요.
설정
데모용으로 가져오기 및 간단한 데이터세트로 시작합니다.
import tensorflow as tf
import tensorflow.compat.v1 as tf1
2022-12-14 20:42:39.819124: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:42:39.819214: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 20:42:39.819222: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
# Define an input function.
def _input_fn():
return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
TensorFlow 1: tf.estimator API를 사용하여 텐서를 기록하고 학습을 중지합니다.
TensorFlow 1에서는 훈련 동작을 제어하기 위해 다양한 후크를 정의합니다. 그런 다음 이 후크를 tf.estimator.EstimatorSpec
전달합니다.
아래 예에서:
- 텐서(예: 모델 가중치 또는 손실)를 모니터링/로그하려면
tf.estimator.LoggingTensorHook
(tf.train.LoggingTensorHook
은 별칭)을 사용합니다. - 특정 단계에서 훈련을 중지하려면
tf.estimator.StopAtStepHook
(tf.train.StopAtStepHook
은 별칭)을 사용합니다.
def _model_fn(features, labels, mode):
dense = tf1.layers.Dense(1)
logits = dense(features)
loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
optimizer = tf1.train.AdagradOptimizer(0.05)
train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
# Define the stop hook.
stop_hook = tf1.train.StopAtStepHook(num_steps=2)
# Access tensors to be logged by names.
kernel_name = tf.identity(dense.weights[0])
bias_name = tf.identity(dense.weights[1])
logging_weight_hook = tf1.train.LoggingTensorHook(
tensors=[kernel_name, bias_name],
every_n_iter=1)
# Log the training loss by the tensor object.
logging_loss_hook = tf1.train.LoggingTensorHook(
{'loss from LoggingTensorHook': loss},
every_n_secs=3)
# Pass all hooks to `EstimatorSpec`.
return tf1.estimator.EstimatorSpec(mode,
loss=loss,
train_op=train_op,
training_hooks=[stop_hook,
logging_weight_hook,
logging_loss_hook])
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
estimator.train(_input_fn)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpmxcdczk2 INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpmxcdczk2', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpmxcdczk2/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 0.03374261, step = 0 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[ 0.5439496 ] [-0.04017198]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [0.] INFO:tensorflow:loss from LoggingTensorHook = 0.03374261 INFO:tensorflow:Tensor("Identity:0", shape=(2, 1), dtype=float32) = [[ 0.5060545 ] [-0.08353905]], Tensor("Identity_1:0", shape=(1,), dtype=float32) = [-0.03789507] (0.030 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 2... INFO:tensorflow:Saving checkpoints for 2 into /tmpfs/tmp/tmpmxcdczk2/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 2... INFO:tensorflow:Loss for final step: 0.07041928. <tensorflow_estimator.python.estimator.estimator.Estimator at 0x7f1a740a99d0>
TensorFlow 2: 사용자 지정 콜백 및 Model.fit을 사용하여 텐서를 기록하고 훈련을 중지합니다.
TensorFlow 2에서 Model.fit
(또는 Model.evaluate
tf.keras.callbacks.Callback
을 정의하여 텐서 모니터링 및 학습 중지를 구성할 수 있습니다. 그런 다음 이를 Model.fit
(또는 Model.evaluate
) callbacks
매개변수에 전달합니다. (자신만의 콜백 작성 가이드에서 자세히 알아보세요.)
아래 예에서:
StopAtStepHook
의 기능을 다시 생성하려면 특정 단계 수 후에 훈련을 중지on_batch_end
메서드를 재정의하는 사용자 지정 콜백(아래에서StopAtStepCallback
LoggingTensorHook
동작을 다시 생성하려면 이름으로 텐서에 액세스하는 것이 지원되지 않으므로 로깅된 텐서를 수동으로 기록하고 출력하는 사용자 지정 콜백(LoggingTensorCallback
사용자 정의 콜백 내에서 로깅 빈도를 구현할 수도 있습니다. 아래 예에서는 두 단계마다 가중치를 인쇄합니다. N초마다 기록하는 것과 같은 다른 전략도 가능합니다.
class StopAtStepCallback(tf.keras.callbacks.Callback):
def __init__(self, stop_step=None):
super().__init__()
self._stop_step = stop_step
def on_batch_end(self, batch, logs=None):
if self.model.optimizer.iterations >= self._stop_step:
self.model.stop_training = True
print('\nstop training now')
class LoggingTensorCallback(tf.keras.callbacks.Callback):
def __init__(self, every_n_iter):
super().__init__()
self._every_n_iter = every_n_iter
self._log_count = every_n_iter
def on_batch_end(self, batch, logs=None):
if self._log_count > 0:
self._log_count -= 1
print("Logging Tensor Callback: dense/kernel:",
model.layers[0].weights[0])
print("Logging Tensor Callback: dense/bias:",
model.layers[0].weights[1])
print("Logging Tensor Callback loss:", logs["loss"])
else:
self._log_count -= self._every_n_iter
완료되면 새로운 콜백인 StopAtStepCallback
및 LoggingTensorCallback
을 Model.fit의 callbacks
매개변수에 Model.fit
.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")
# Begin training.
# The training will stop after 2 steps, and the weights/loss will also be logged.
model.fit(dataset, callbacks=[StopAtStepCallback(stop_step=2),
LoggingTensorCallback(every_n_iter=2)])
Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[-0.19878045], [-0.5605426 ]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.04971096], dtype=float32)> Logging Tensor Callback loss: 2.143622636795044 1/3 [=========>....................] - ETA: 0s - loss: 2.1436 stop training now Logging Tensor Callback: dense/kernel: <tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy= array([[-0.1512619], [-0.5139848]], dtype=float32)> Logging Tensor Callback: dense/bias: <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.09154248], dtype=float32)> Logging Tensor Callback loss: 3.60127592086792 3/3 [==============================] - 0s 4ms/step - loss: 3.6013 <keras.callbacks.History at 0x7f1a708f12e0>
다음 단계
콜백에 대해 자세히 알아보기:
- API 문서:
tf.keras.callbacks.Callback
- 가이드: 자신만의 콜백 작성하기
- 가이드: 기본 제공 메서드를 사용한 교육 및 평가 ( 콜백 사용 섹션)
다음과 같은 마이그레이션 관련 리소스도 유용할 수 있습니다.
- 조기 중지 마이그레이션 가이드 :
tf.keras.callbacks.EarlyStopping
은 조기 중지 콜백이 내장되어 있습니다. - TensorBoard 마이그레이션 가이드 : TensorBoard는 측정항목을 추적하고 표시할 수 있습니다.
- 지원 로직 마이그레이션 가이드를 통한 교육 :
SessionRunHook
에서 TensorFlow 2의 Keras 콜백까지