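Warning: Estimators are not recommended for new code. Estimators run v1.Session-style code, which is more difficult to write correctly and can behave unexpectedly, especially when combined with TF 2 code. Estimators do fall under the compatibility guarantees, but they will receive no fixes other than for security vulnerabilities. See the migration guide for details.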
Overview
TensorFlow Estimators are supported in TensorFlow and can be created from new and existing tf.keras models. This tutorial contains a complete, minimal example of that process.
Note: If you have a Keras model, you can use it directly with tf.distribute strategies without converting it to an Estimator. As such, model_to_estimator is no longer recommended.
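As a rough sketch of the alternative the note describes, a comparable model can be built and compiled inside a tf.distribute strategy scope and trained with model.fit directly, with no Estimator involved. This assumes tf.distribute.MirroredStrategy (which uses a single CPU replica when no GPU is present) and uses random placeholder arrays instead of the iris data; neither choice comes from the original tutorial.

```python
import numpy as np
import tensorflow as tf

# Assumption: MirroredStrategy; with no GPUs it runs a single CPU replica.
strategy = tf.distribute.MirroredStrategy()

# Model variables must be created inside the strategy scope.
with strategy.scope():
    dist_model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(3),
    ])
    dist_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer='adam')

# Hypothetical placeholder data: 64 random examples, 4 features, 3 classes.
rng = np.random.default_rng(0)
x = rng.random((64, 4), dtype=np.float32)
y = rng.integers(0, 3, size=64)

# Keras distributes each batch across the strategy's replicas.
history = dist_model.fit(x, y, batch_size=32, epochs=1, verbose=0)
```

The model is named dist_model only so that it does not overwrite the model defined later in this tutorial.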
Setup
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
Create a simple Keras model
In Keras, you assemble layers to build models. A model is (usually) a graph of layers. The most common type of model is a stack of layers: the tf.keras.Sequential model.
To build a simple, fully connected network (that is, a multilayer perceptron):
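model = tf.keras.models.Sequential([
    # Hidden layer: 4 input features -> 16 ReLU units.
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    # Randomly zero 20% of the activations during training.
    tf.keras.layers.Dropout(0.2),
    # Output layer: one raw score (logit) per iris class.
    tf.keras.layers.Dense(3)
])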
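The final Dense(3) layer has no activation, so the model outputs raw, unnormalized scores (logits), which is why the compile step uses SparseCategoricalCrossentropy(from_logits=True). As an illustrative NumPy sketch (not TensorFlow's own implementation), sparse categorical cross-entropy on logits applies a log-softmax and averages the negative log-probability of each example's integer label:

```python
import numpy as np

def sparse_softmax_cross_entropy(logits, labels):
    """Mean cross-entropy between integer class labels and raw logits."""
    # Subtract the row-wise max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Select the log-probability of the true class in each row.
    return -log_probs[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.0, 0.0, 0.0]])
labels = np.array([0, 2])
loss = sparse_softmax_cross_entropy(logits, labels)
print(loss)
```

With all-zero logits every class gets probability 1/3, so the loss for such a row is log(3).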
Compile the model and print a summary.
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer='adam')
model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 dense (Dense)               (None, 16)                80
 dropout (Dropout)           (None, 16)                0
 dense_1 (Dense)             (None, 3)                 51
=================================================================
Total params: 131
Trainable params: 131
Non-trainable params: 0
_________________________________________________________________
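The Param # column follows directly from the layer sizes: a Dense layer with n_in inputs and n_out units holds an n_in * n_out weight matrix plus n_out biases, and Dropout has no weights. A quick check in plain Python:

```python
def dense_param_count(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out

layer_params = {
    'dense': dense_param_count(4, 16),    # 4 iris features -> 16 units
    'dropout': 0,                         # Dropout has no trainable weights
    'dense_1': dense_param_count(16, 3),  # 16 units -> 3 class logits
}
total_params = sum(layer_params.values())
print(layer_params, total_params)  # {'dense': 80, 'dropout': 0, 'dense_1': 51} 131
```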
Create an input function
Use the Datasets API to scale to large datasets or multi-device training.
Estimators need control over when and how their input pipeline is built. To allow this, they require an "input function", or input_fn. The Estimator calls this function with no arguments, and the input_fn must return a tf.data.Dataset.
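def input_fn():
    split = tfds.Split.TRAIN
    dataset = tfds.load('iris', split=split, as_supervised=True)
    # Wrap the features in a dict keyed by the Keras input layer name ('dense_input').
    dataset = dataset.map(lambda features, labels: ({'dense_input': features}, labels))
    # Batch, then repeat indefinitely; training and evaluation are bounded by `steps`.
    dataset = dataset.batch(32).repeat()
    return dataset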
Test your input_fn:
for features_batch, labels_batch in input_fn().take(1):
    print(features_batch)
    print(labels_batch)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23.
Instructions for updating:
Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089
{'dense_input': <tf.Tensor: shape=(32, 4), dtype=float32, numpy=
array([[5.1, 3.4, 1.5, 0.2],
       [7.7, 3. , 6.1, 2.3],
       [5.7, 2.8, 4.5, 1.3],
       [6.8, 3.2, 5.9, 2.3],
       [5.2, 3.4, 1.4, 0.2],
       [5.6, 2.9, 3.6, 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [5.5, 2.4, 3.7, 1. ],
       [4.6, 3.4, 1.4, 0.3],
       [7.7, 2.8, 6.7, 2. ],
       [7. , 3.2, 4.7, 1.4],
       [4.6, 3.2, 1.4, 0.2],
       [6.5, 3. , 5.2, 2. ],
       [5.5, 4.2, 1.4, 0.2],
       [5.4, 3.9, 1.3, 0.4],
       [5. , 3.5, 1.3, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [4.8, 3. , 1.4, 0.1],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [6.7, 3.3, 5.7, 2.1],
       [7.9, 3.8, 6.4, 2. ],
       [6.7, 3. , 5.2, 2.3],
       [5.8, 4. , 1.2, 0.2],
       [6.3, 2.5, 5. , 1.9],
       [5. , 3. , 1.6, 0.2],
       [6.9, 3.1, 5.1, 2.3],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.7, 4.1, 1. ],
       [5.2, 2.7, 3.9, 1.4],
       [6.7, 3. , 5. , 1.7],
       [5.7, 2.6, 3.5, 1. ]], dtype=float32)>}
tf.Tensor([0 2 1 2 0 1 1 1 0 2 1 0 2 0 0 0 0 0 2 2 2 2 2 0 2 0 2 1 1 1 1 1], shape=(32,), dtype=int64)
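Because the Estimator calls input_fn with no arguments, one common way to parameterize an input function is to bind its arguments ahead of time with functools.partial. The sketch below assumes in-memory NumPy arrays (random placeholders here) instead of TFDS, and the helper name numpy_input_fn is hypothetical. The feature dict key must still match the Keras input layer name; in this tutorial that is 'dense_input', the name Keras generated for the input of the first layer, dense.

```python
import functools

import numpy as np
import tensorflow as tf

def numpy_input_fn(features, labels, batch_size=32, shuffle=True):
    """Hypothetical helper: builds the (features_dict, labels) dataset an Estimator expects."""
    # The key must match the Keras model's input layer name ('dense_input' here).
    dataset = tf.data.Dataset.from_tensor_slices(({'dense_input': features}, labels))
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(labels))
    # Repeat indefinitely; the caller bounds training with `steps`.
    return dataset.batch(batch_size).repeat()

# Placeholder data standing in for the iris features and labels.
rng = np.random.default_rng(0)
x = rng.random((100, 4), dtype=np.float32)
y = rng.integers(0, 3, size=100)

# A zero-argument callable, suitable for passing as input_fn=...
train_input_fn = functools.partial(numpy_input_fn, x, y, batch_size=32)

features_batch, labels_batch = next(iter(train_input_fn().take(1)))
```

A lambda such as lambda: numpy_input_fn(x, y) works the same way; partial simply keeps the bound arguments visible.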
Create an Estimator from the tf.keras model
A tf.keras.Model can be trained with the tf.estimator API by converting the model to a tf.estimator.Estimator object with tf.keras.estimator.model_to_estimator.
import tempfile

model_dir = tempfile.mkdtemp()
keras_estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model, model_dir=model_dir)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using the Keras model provided.
WARNING:absl:You are using `tf.keras.optimizers.experimental.Optimizer` in TF estimator, which only supports `tf.keras.optimizers.legacy.Optimizer`. Automatically converting your optimizer to `tf.keras.optimizers.legacy.Optimizer`.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/backend.py:451: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn(
INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmp9h0gesw0', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
2022-12-15 02:45:25.103464: W tensorflow/c/c_api.cc:291] Operation '{name:'training/Adam/dense_1/bias/v/Assign' id:218 op device:{requested: '', assigned: ''} def:{ { {node training/Adam/dense_1/bias/v/Assign} } = AssignVariableOp[_has_manual_control_dependencies=true, dtype=DT_FLOAT, validate_shape=false](training/Adam/dense_1/bias/v, training/Adam/dense_1/bias/v/Initializer/zeros)} }' was changed by setting attribute after it was run by a session. This mutation will have no effect, and will trigger an error in the future. Either don't modify nodes after running them or create a new session.
Train and evaluate the Estimator.
keras_estimator.train(input_fn=input_fn, steps=500)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmpfs/tmp/tmp9h0gesw0/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmpfs/tmp/tmp9h0gesw0/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 4 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmp9h0gesw0/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 2.432807, step = 0
INFO:tensorflow:global_step/sec: 44.6531
INFO:tensorflow:loss = 1.0671942, step = 100 (2.241 sec)
INFO:tensorflow:global_step/sec: 46.4651
INFO:tensorflow:loss = 0.82943535, step = 200 (2.152 sec)
INFO:tensorflow:global_step/sec: 46.1527
INFO:tensorflow:loss = 0.6606032, step = 300 (2.167 sec)
INFO:tensorflow:global_step/sec: 46.4966
INFO:tensorflow:loss = 0.56208, step = 400 (2.151 sec)
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500...
INFO:tensorflow:Saving checkpoints for 500 into /tmpfs/tmp/tmp9h0gesw0/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500...
INFO:tensorflow:Loss for final step: 0.5922533.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/engine/training_v1.py:2333: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  updates = self.state_updates
INFO:tensorflow:Starting evaluation at 2022-12-15T02:45:38
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmp9h0gesw0/model.ckpt-500
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.50137s
INFO:tensorflow:Finished evaluation at 2022-12-15-02:45:39
INFO:tensorflow:Saving dict for global step 500: global_step = 500, loss = 0.5014641
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmpfs/tmp/tmp9h0gesw0/model.ckpt-500
Eval result: {'loss': 0.5014641, 'global_step': 500}
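Because input_fn ends with .repeat(), the dataset never runs out, so train and evaluate stop only because of their steps argument. The arithmetic below estimates how much data those step counts cover, assuming the TFDS iris train split holds 150 examples. Since .batch(32) is applied before .repeat(), each pass over the data ends with a smaller final batch. Keep in mind as well that this evaluation reuses the training input_fn, so the reported loss is measured on training data rather than on a held-out set.

```python
import math

IRIS_TRAIN_EXAMPLES = 150  # assumed size of the TFDS 'iris' train split
BATCH_SIZE = 32            # from dataset.batch(32) in input_fn

# batch() runs before repeat(), so every pass over the data is
# 4 full batches of 32 followed by one partial batch of 150 - 128 = 22.
steps_per_pass = math.ceil(IRIS_TRAIN_EXAMPLES / BATCH_SIZE)
last_batch_size = IRIS_TRAIN_EXAMPLES - (steps_per_pass - 1) * BATCH_SIZE

def passes_over_data(steps):
    """Number of passes over the train split covered by `steps` batches."""
    return steps / steps_per_pass

train_passes = passes_over_data(500)  # keras_estimator.train(..., steps=500)
eval_passes = passes_over_data(10)    # keras_estimator.evaluate(..., steps=10)
print(steps_per_pass, last_batch_size, train_passes, eval_passes)  # 5 22 100.0 2.0
```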