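Canned (or premade) Estimators have traditionally been used in TensorFlow 1 as quick and easy ways to train models for a variety of typical use cases. TensorFlow 2 provides straightforward approximate substitutes for a number of them by way of Keras models. For canned Estimators that have no built-in TensorFlow 2 substitute, you can still build your own replacement fairly easily.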
This guide walks through a few examples of direct equivalents and custom substitutions to show how TensorFlow 1 models derived from tf.estimator can be migrated to TensorFlow 2 with Keras.
Namely, this guide includes examples for migrating:
- From tf.estimator's LinearEstimator, LinearClassifier or LinearRegressor in TensorFlow 1 to Keras tf.compat.v1.keras.experimental.LinearModel in TensorFlow 2
- From tf.estimator's DNNEstimator, DNNClassifier or DNNRegressor in TensorFlow 1 to a custom Keras DNN model in TensorFlow 2
- From tf.estimator's DNNLinearCombinedEstimator, DNNLinearCombinedClassifier or DNNLinearCombinedRegressor in TensorFlow 1 to tf.compat.v1.keras.experimental.WideDeepModel in TensorFlow 2
- From tf.estimator's BoostedTreesEstimator, BoostedTreesClassifier or BoostedTreesRegressor in TensorFlow 1 to tfdf.keras.GradientBoostedTreesModel in TensorFlow 2
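A common step in model training is feature preprocessing. For TensorFlow 1 Estimator models, this is done with tf.feature_column. For more information on feature preprocessing in TensorFlow 2, see the guide on migrating from feature columns to the Keras preprocessing layers API.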
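A Keras preprocessing layer such as tf.keras.layers.Normalization can take the place of a plain numeric_column. It learns each feature's mean and variance from data with adapt() and then standardizes its inputs. The stdlib-only sketch below mirrors that computation at a conceptual level. It is not the layer's implementation, and the helper names and sample ages are hypothetical.

```python
import math

# Conceptual sketch of feature standardization, in the spirit of a Keras
# Normalization layer: learn statistics once, then apply them to inputs.
# Helper names (adapt, normalize) and the sample values are hypothetical.
def adapt(values):
    """Return the mean and population variance of a list of numbers."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, variance

def normalize(values, mean, variance, epsilon=1e-7):
    """Standardize values to zero mean and unit variance."""
    std = max(math.sqrt(variance), epsilon)  # guard against division by zero
    return [(v - mean) / std for v in values]

ages = [22.0, 38.0, 26.0, 35.0, 28.0]
mean, variance = adapt(ages)
scaled = normalize(ages, mean, variance)
print(mean, [round(s, 3) for s in scaled])
```

The statistics are computed once from training data and then reused, so evaluation data is scaled with the training mean and variance rather than its own.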
Setup
Start with a few necessary TensorFlow imports:
pip install tensorflow_decision_forests
import keras
import pandas as pd
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_decision_forests as tfdf
Prepare some simple data for demonstration from the standard Titanic dataset:
x_train = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
x_eval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')
# Encode the categorical string columns as integers and drop unused columns.
# Assigning the result back (rather than calling `replace(..., inplace=True)`
# on a selected column) avoids pandas chained-assignment problems.
for df in (x_train, x_eval):
  df['sex'] = df['sex'].map({'male': 0, 'female': 1})
  df['alone'] = df['alone'].map({'n': 0, 'y': 1})
  df['class'] = df['class'].map({'First': 1, 'Second': 2, 'Third': 3})
  df.drop(['embark_town', 'deck'], axis=1, inplace=True)
y_train = x_train.pop('survived')
y_eval = x_eval.pop('survived')
# Data setup for TensorFlow 1 with `tf.estimator`
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((dict(x_train), y_train)).batch(32)
def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices((dict(x_eval), y_eval)).batch(32)
FEATURE_NAMES = [
    'age', 'fare', 'sex', 'n_siblings_spouses', 'parch', 'class', 'alone'
]
feature_columns = []
for fn in FEATURE_NAMES:
  feat_col = tf1.feature_column.numeric_column(fn, dtype=tf.float32)
  feature_columns.append(feat_col)
Next, create a method that instantiates a simple sample optimizer to use with the various TensorFlow 1 Estimator and TensorFlow 2 Keras models:
def create_sample_optimizer(tf_version):
  if tf_version == 'tf1':
    # Estimators accept a callable, so the optimizer and its global-step-based
    # learning rate schedule are created inside the Estimator's graph.
    optimizer = lambda: tf.keras.optimizers.legacy.Ftrl(
        l1_regularization_strength=0.001,
        learning_rate=tf1.train.exponential_decay(
            learning_rate=0.1,
            global_step=tf1.train.get_global_step(),
            decay_steps=10000,
            decay_rate=0.9))
  elif tf_version == 'tf2':
    optimizer = tf.keras.optimizers.legacy.Ftrl(
        l1_regularization_strength=0.001,
        learning_rate=tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=0.1, decay_steps=10000, decay_rate=0.9))
  else:
    raise ValueError(f"tf_version must be 'tf1' or 'tf2', got {tf_version!r}")
  return optimizer
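Both branches use the same decay schedule. With the default staircase=False, tf1.train.exponential_decay and tf.keras.optimizers.schedules.ExponentialDecay both compute `learning_rate * decay_rate ** (step / decay_steps)`. The plain-Python sketch below evaluates that formula only to show how slowly the rate decays with these settings. The function name is hypothetical.

```python
# Sketch of the non-staircase exponential decay formula shared by
# tf1.train.exponential_decay and tf.keras.optimizers.schedules.ExponentialDecay.
# The function name is hypothetical; it is not a TensorFlow API.
def exponential_decay(initial_learning_rate, step, decay_steps, decay_rate):
    return initial_learning_rate * decay_rate ** (step / decay_steps)

for step in (0, 100, 10000, 20000):
    lr = exponential_decay(0.1, step, decay_steps=10000, decay_rate=0.9)
    print(step, round(lr, 6))
```

The examples in this guide run only a few hundred training steps. Over that range the learning rate stays within roughly 0.5% of 0.1, so the schedule matters mainly for much longer training runs.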
Example 1: Migrating from LinearEstimator
TensorFlow 1: Using LinearEstimator
In TensorFlow 1, you can use tf.estimator.LinearEstimator to create a baseline linear model for regression and classification problems.
linear_estimator = tf.estimator.LinearEstimator(
head=tf.estimator.BinaryClassHead(),
feature_columns=feature_columns,
optimizer=create_sample_optimizer('tf1'))
linear_estimator.train(input_fn=_input_fn, steps=100)
linear_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
{'accuracy': 0.70075756, 'accuracy_baseline': 0.625, 'auc': 0.75472915, 'auc_precision_recall': 0.65362054, 'average_loss': 0.5759378, 'label/mean': 0.375, 'loss': 0.5704812, 'precision': 0.6388889, 'prediction/mean': 0.41331062, 'recall': 0.46464646, 'global_step': 20}
TensorFlow 2: Using Keras LinearModel
In TensorFlow 2, you can create an instance of the Keras tf.compat.v1.keras.experimental.LinearModel, which is the substitute for tf.estimator.LinearEstimator. The tf.compat.v1.keras path indicates that the premade model exists for compatibility.
linear_model = tf.compat.v1.keras.experimental.LinearModel()
linear_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])
linear_model.fit(x_train, y_train, epochs=10)
linear_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10 20/20 [==============================] - 1s 2ms/step - loss: 4.3678 - accuracy: 0.6045 Epoch 2/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2211 - accuracy: 0.6715 Epoch 3/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2662 - accuracy: 0.6635 Epoch 4/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1962 - accuracy: 0.6970 Epoch 5/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1922 - accuracy: 0.7225 Epoch 6/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1716 - accuracy: 0.7576 Epoch 7/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1725 - accuracy: 0.7656 Epoch 8/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1801 - accuracy: 0.7608 Epoch 9/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1639 - accuracy: 0.8070 Epoch 10/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1635 - accuracy: 0.8038 9/9 [==============================] - 0s 2ms/step - loss: 0.1796 - accuracy: 0.7462 {'loss': 0.1795806735754013, 'accuracy': 0.7462121248245239}
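At a conceptual level, a linear model predicts a weighted sum of the features plus a bias. The Keras example compiles with loss='mse' and metrics=['accuracy'] on 0/1 labels and a single output. For that setup, Keras typically resolves 'accuracy' to binary accuracy, which compares predictions thresholded at 0.5 with the labels.

The stdlib sketch below makes this concrete on hypothetical toy data. It trains with plain gradient descent instead of Ftrl, so it illustrates the idea only and is not how LinearModel is implemented.

```python
# Conceptual sketch: linear model y_hat = w . x + b trained with plain gradient
# descent on mean squared error (not the Ftrl optimizer used in this guide).
# Data and helper names are hypothetical.
def predict(weights, bias, x):
    return sum(w * xi for w, xi in zip(weights, x)) + bias

def mse(weights, bias, xs, ys):
    return sum((predict(weights, bias, x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

def binary_accuracy(weights, bias, xs, ys, threshold=0.5):
    hits = sum((predict(weights, bias, x) > threshold) == bool(y) for x, y in zip(xs, ys))
    return hits / len(xs)

def train(xs, ys, lr=0.3, epochs=1000):
    n_features = len(xs[0])
    weights, bias = [0.0] * n_features, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * n_features, 0.0
        for x, y in zip(xs, ys):
            err = predict(weights, bias, x) - y
            for i, xi in enumerate(x):
                grad_w[i] += 2 * err * xi / len(xs)  # d(MSE)/dw_i
            grad_b += 2 * err / len(xs)              # d(MSE)/db
        weights = [w - lr * g for w, g in zip(weights, grad_w)]
        bias -= lr * grad_b
    return weights, bias

# Toy data: two features; the label is 1 when the first feature is large.
xs = [[0.0, 1.0], [0.2, 0.5], [0.9, 0.3], [1.0, 0.8]]
ys = [0, 0, 1, 1]
weights, bias = train(xs, ys)
print(mse(weights, bias, xs, ys), binary_accuracy(weights, bias, xs, ys))
```

This also explains why the Keras run reports both a small MSE loss and an accuracy: the loss measures how close the raw outputs are to 0 or 1, while accuracy only checks which side of 0.5 they fall on.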
Example 2: Migrating from DNNEstimator
TensorFlow 1: Using DNNEstimator
In TensorFlow 1, you can use tf.estimator.DNNEstimator to create a baseline deep neural network (DNN) model for regression and classification problems.
dnn_estimator = tf.estimator.DNNEstimator(
head=tf.estimator.BinaryClassHead(),
feature_columns=feature_columns,
hidden_units=[128],
activation_fn=tf.nn.relu,
optimizer=create_sample_optimizer('tf1'))
dnn_estimator.train(input_fn=_input_fn, steps=100)
dnn_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
{'accuracy': 0.6931818, 'accuracy_baseline': 0.625, 'auc': 0.69877565, 'auc_precision_recall': 0.60144323, 'average_loss': 0.6064432, 'label/mean': 0.375, 'loss': 0.5960985, 'precision': 0.60714287, 'prediction/mean': 0.3887555, 'recall': 0.5151515, 'global_step': 20}
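TensorFlow 2: Using Keras to create a custom DNN model
In TensorFlow 2, you can create a custom DNN model to substitute for one generated by tf.estimator.DNNEstimator. It allows a similar level of user-specified customization, for instance the ability to customize the chosen model optimizer, as in the previous example.
A similar workflow can be used to replace tf.estimator.experimental.RNNEstimator with a Keras recurrent neural network (RNN) model. Keras provides a number of built-in, customizable choices through tf.keras.layers.RNN, tf.keras.layers.LSTM, and tf.keras.layers.GRU. To learn more, check out the "Built-in RNN layers: a simple example" section of the Recurrent Neural Networks (RNN) with Keras guide.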
dnn_model = tf.keras.models.Sequential(
[tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(1)])
dnn_model.compile(loss='mse', optimizer=create_sample_optimizer('tf2'), metrics=['accuracy'])
dnn_model.fit(x_train, y_train, epochs=10)
dnn_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10 20/20 [==============================] - 0s 2ms/step - loss: 638.0258 - accuracy: 0.5518 Epoch 2/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2738 - accuracy: 0.6459 Epoch 3/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2556 - accuracy: 0.6762 Epoch 4/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1932 - accuracy: 0.6986 Epoch 5/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1908 - accuracy: 0.7337 Epoch 6/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2015 - accuracy: 0.7368 Epoch 7/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1910 - accuracy: 0.7432 Epoch 8/10 20/20 [==============================] - 0s 2ms/step - loss: 0.2264 - accuracy: 0.7289 Epoch 9/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1711 - accuracy: 0.7735 Epoch 10/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1928 - accuracy: 0.7656 9/9 [==============================] - 0s 2ms/step - loss: 0.1969 - accuracy: 0.7121 {'loss': 0.19687658548355103, 'accuracy': 0.7121211886405945}
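The Sequential model above is two Dense layers. The first applies ReLU, and the second is linear and produces a single output that the MSE loss compares with the label.

The stdlib sketch below spells out that forward pass with hand-picked, hypothetical weights. It stores weights as one row per output unit, which is the transpose of a Keras Dense kernel's (input_dim, units) layout.

```python
# Conceptual forward pass of Dense(hidden, activation='relu') -> Dense(1),
# written with plain lists. All weights are hypothetical, hand-picked values.
def relu(v):
    return max(0.0, v)

def dense(inputs, weights, biases, activation=None):
    """weights[j] holds the input weights of output unit j."""
    outputs = []
    for unit_weights, bias in zip(weights, biases):
        z = sum(w * x for w, x in zip(unit_weights, inputs)) + bias
        outputs.append(activation(z) if activation else z)
    return outputs

x = [1.0, 2.0]
hidden = dense(x, weights=[[1.0, -1.0], [0.5, 0.5]], biases=[0.0, -1.0], activation=relu)
output = dense(hidden, weights=[[2.0, 4.0]], biases=[0.1])
print(hidden, output)
```

The first hidden unit's pre-activation is negative, so ReLU zeroes it and only the second unit contributes to the output.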
Example 3: Migrating from DNNLinearCombinedEstimator
TensorFlow 1: Using DNNLinearCombinedEstimator
In TensorFlow 1, you can use tf.estimator.DNNLinearCombinedEstimator to create a baseline combined model for regression and classification problems, with customization capacity for both its linear and DNN components.
optimizer = create_sample_optimizer('tf1')
combined_estimator = tf.estimator.DNNLinearCombinedEstimator(
head=tf.estimator.BinaryClassHead(),
# Wide settings
linear_feature_columns=feature_columns,
linear_optimizer=optimizer,
# Deep settings
dnn_feature_columns=feature_columns,
dnn_hidden_units=[128],
dnn_optimizer=optimizer)
combined_estimator.train(input_fn=_input_fn, steps=100)
combined_estimator.evaluate(input_fn=_eval_input_fn, steps=10)
{'accuracy': 0.7121212, 'accuracy_baseline': 0.625, 'auc': 0.7426691, 'auc_precision_recall': 0.6507447, 'average_loss': 0.57682097, 'label/mean': 0.375, 'loss': 0.56425, 'precision': 0.6419753, 'prediction/mean': 0.40049776, 'recall': 0.5252525, 'global_step': 20}
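TensorFlow 2: Using Keras WideDeepModel
In TensorFlow 2, you can create an instance of the Keras tf.compat.v1.keras.experimental.WideDeepModel to substitute for one generated by tf.estimator.DNNLinearCombinedEstimator. It allows a similar level of user-specified customization, for instance the ability to customize the chosen model optimizer, as in the previous example.
This WideDeepModel is built from a constituent LinearModel and a custom DNN model, both of which are discussed in the preceding two examples. You can also use a custom linear model in place of the built-in LinearModel if desired.
If you would like to build your own model instead of using a canned Estimator, check out the Keras Sequential model guide. For more information on custom training and optimizers, check out the Custom training: walkthrough guide.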
# Create LinearModel and DNN Model as in Examples 1 and 2
optimizer = create_sample_optimizer('tf2')
linear_model = tf.compat.v1.keras.experimental.LinearModel()
linear_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
linear_model.fit(x_train, y_train, epochs=10, verbose=0)
dnn_model = tf.keras.models.Sequential(
[tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(1)])
dnn_model.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])
combined_model = tf.compat.v1.keras.experimental.WideDeepModel(linear_model,
dnn_model)
combined_model.compile(
optimizer=[optimizer, optimizer], loss='mse', metrics=['accuracy'])
combined_model.fit([x_train, x_train], y_train, epochs=10)
combined_model.evaluate(x_eval, y_eval, return_dict=True)
Epoch 1/10 20/20 [==============================] - 0s 3ms/step - loss: 690.1404 - accuracy: 0.5407 Epoch 2/10 20/20 [==============================] - 0s 2ms/step - loss: 0.4099 - accuracy: 0.6826 Epoch 3/10 20/20 [==============================] - 0s 3ms/step - loss: 0.2995 - accuracy: 0.6842 Epoch 4/10 20/20 [==============================] - 0s 2ms/step - loss: 0.3216 - accuracy: 0.7097 Epoch 5/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1901 - accuracy: 0.7512 Epoch 6/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1709 - accuracy: 0.7687 Epoch 7/10 20/20 [==============================] - 0s 3ms/step - loss: 0.1696 - accuracy: 0.7911 Epoch 8/10 20/20 [==============================] - 0s 3ms/step - loss: 0.1612 - accuracy: 0.7911 Epoch 9/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1543 - accuracy: 0.7974 Epoch 10/10 20/20 [==============================] - 0s 2ms/step - loss: 0.1520 - accuracy: 0.8070 9/9 [==============================] - 0s 2ms/step - loss: 0.2121 - accuracy: 0.7045 {'loss': 0.212116077542305, 'accuracy': 0.7045454382896423}
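The idea behind wide-and-deep models, as in DNNLinearCombinedEstimator, is that the linear (wide) part and the DNN (deep) part each produce a logit. The two logits are added, and for binary classification the sum then goes through a sigmoid.

The stdlib sketch below shows that combination with hand-picked, hypothetical weights. Keras's WideDeepModel takes an optional activation argument, and its exact combination may differ in detail. Treat this as the concept rather than the exact Keras computation.

```python
import math

# Conceptual wide-and-deep combination: add a linear logit and a DNN logit,
# then apply a sigmoid. All weights are hypothetical, hand-picked values.
def linear_logit(x, weights, bias):
    return sum(w * xi for w, xi in zip(weights, x)) + bias

def dnn_logit(x, hidden_weights, hidden_biases, out_weights, out_bias):
    # One ReLU hidden layer followed by a linear output unit.
    hidden = [max(0.0, sum(w * xi for w, xi in zip(ws, x)) + b)
              for ws, b in zip(hidden_weights, hidden_biases)]
    return sum(w * h for w, h in zip(out_weights, hidden)) + out_bias

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

x = [2.0, 4.0]
wide = linear_logit(x, weights=[0.5, -0.25], bias=0.1)
deep = dnn_logit(x, hidden_weights=[[1.0, 0.0], [0.0, 1.0]], hidden_biases=[0.0, -5.0],
                 out_weights=[0.45, 1.0], out_bias=0.0)
probability = sigmoid(wide + deep)
print(wide, deep, probability)
```

Because the logits are summed before the sigmoid, either part can dominate a prediction. This is also why the Keras version can train each part with its own optimizer, as the `optimizer=[optimizer, optimizer]` argument above allows.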
Example 4: Migrating from BoostedTreesEstimator
TensorFlow 1: Using BoostedTreesEstimator
In TensorFlow 1, you could use tf.estimator.BoostedTreesEstimator to create a baseline gradient boosting model that uses an ensemble of decision trees for regression and classification problems. This functionality is no longer included in TensorFlow 2.
bt_estimator = tf1.estimator.BoostedTreesEstimator(
head=tf.estimator.BinaryClassHead(),
n_batches_per_layer=1,
max_depth=10,
n_trees=1000,
feature_columns=feature_columns)
bt_estimator.train(input_fn=_input_fn, steps=1000)
bt_estimator.evaluate(input_fn=_eval_input_fn, steps=100)
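Gradient boosting builds the ensemble one tree at a time. Each new tree is fit to the errors of the trees so far (for squared loss, the residuals), and its prediction is added with a shrinkage factor. BoostedTreesEstimator and TF-DF's GradientBoostedTreesModel both implement far more elaborate versions of this loop.

The stdlib sketch below uses depth-1 trees (stumps) on hypothetical 1-D data to show the loop. It is not the algorithm of either library, and all names and hyperparameters are assumptions.

```python
# Conceptual sketch of gradient boosting with squared loss and decision stumps.
# Data, helper names, and hyperparameters are hypothetical.
def fit_stump(xs, targets):
    """Find the split x <= threshold that minimizes squared error on targets."""
    best = None
    for threshold in sorted(set(xs))[:-1]:  # the largest x would leave the right side empty
        left = [t for x, t in zip(xs, targets) if x <= threshold]
        right = [t for x, t in zip(xs, targets) if x > threshold]
        left_mean = sum(left) / len(left)
        right_mean = sum(right) / len(right)
        sse = (sum((t - left_mean) ** 2 for t in left)
               + sum((t - right_mean) ** 2 for t in right))
        if best is None or sse < best[0]:
            best = (sse, threshold, left_mean, right_mean)
    if best is None:  # all xs identical: fall back to a constant prediction
        mean = sum(targets) / len(targets)
        return xs[0], mean, mean
    return best[1:]

def stump_predict(stump, x):
    threshold, left_value, right_value = stump
    return left_value if x <= threshold else right_value

def gradient_boost(xs, ys, n_trees=20, learning_rate=0.5):
    base = sum(ys) / len(ys)  # start from the mean of the labels
    preds = [base] * len(xs)
    stumps = []
    for _ in range(n_trees):
        # For squared loss, the negative gradient is the residual y - prediction.
        residuals = [y - p for y, p in zip(ys, preds)]
        stump = fit_stump(xs, residuals)
        stumps.append(stump)
        preds = [p + learning_rate * stump_predict(stump, x) for p, x in zip(preds, xs)]
    return base, learning_rate, stumps

def boosted_predict(model, x):
    base, learning_rate, stumps = model
    return base + learning_rate * sum(stump_predict(s, x) for s in stumps)

def mse(model, xs, ys):
    return sum((boosted_predict(model, x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

xs = [1, 2, 3, 4, 5, 6]
ys = [0.0, 0.0, 1.0, 1.0, 3.0, 3.0]
baseline = gradient_boost(xs, ys, n_trees=0)  # mean prediction only
model = gradient_boost(xs, ys, n_trees=20)
print(mse(baseline, xs, ys), mse(model, xs, ys))
```

The learning rate of 0.5 means each stump only partly corrects the current residuals. This shrinkage trades more trees for a smoother, less overfit ensemble, and it plays the same role as the shrinkage hyperparameter in production boosted-tree libraries.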
TensorFlow 2: Using TensorFlow Decision Forests
In TensorFlow 2, tf.estimator.BoostedTreesEstimator is replaced by tfdf.keras.GradientBoostedTreesModel from the TensorFlow Decision Forests package.
TensorFlow Decision Forests provides various advantages over tf.estimator.BoostedTreesEstimator, especially in quality, speed, ease of use, and flexibility. To learn about TensorFlow Decision Forests, start with the beginner colab.
The following example shows how to train a Gradient Boosted Trees model using TensorFlow 2:
Install TensorFlow Decision Forests:
pip install tensorflow_decision_forests
Create the TensorFlow datasets. Note that Decision Forests natively support many types of features and do not need preprocessing.
train_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')
eval_dataframe = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')
# Convert the Pandas Dataframes into TensorFlow datasets.
train_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(train_dataframe, label="survived")
eval_dataset = tfdf.keras.pd_dataframe_to_tf_dataset(eval_dataframe, label="survived")
Train the model on the train_dataset dataset:
# Use the default hyper-parameters of the model.
gbt_model = tfdf.keras.GradientBoostedTreesModel()
gbt_model.fit(train_dataset)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. Use /tmpfs/tmp/tmpxdhi_yi2 as temporary training directory Reading training dataset... [WARNING 24-01-11 18:23:37.1620 UTC gradient_boosted_trees.cc:1886] "goss_alpha" set but "sampling_method" not equal to "GOSS". [WARNING 24-01-11 18:23:37.1620 UTC gradient_boosted_trees.cc:1897] "goss_beta" set but "sampling_method" not equal to "GOSS". [WARNING 24-01-11 18:23:37.1621 UTC gradient_boosted_trees.cc:1911] "selective_gradient_boosting_ratio" set but "sampling_method" not equal to "SELGB". Training dataset read in 0:00:03.672133. Found 627 examples. Training model... Model trained in 0:00:00.214528 Compiling model... [INFO 24-01-11 18:23:41.0575 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmpxdhi_yi2/model/ with prefix d8518c499d0543a1 [INFO 24-01-11 18:23:41.0611 UTC quick_scorer_extended.cc:903] The binary was compiled without AVX2 support, but your CPU supports it. Enable it for faster model inference. [INFO 24-01-11 18:23:41.0613 UTC abstract_model.cc:1344] Engine "GradientBoostedTreesQuickScorerExtended" built [INFO 24-01-11 18:23:41.0613 UTC kernel.cc:1061] Use fast generic engine Model compiled. <keras.src.callbacks.History at 0x7ff33c472100>
Evaluate the quality of the model on the eval_dataset dataset:
gbt_model.compile(metrics=['accuracy'])
gbt_evaluation = gbt_model.evaluate(eval_dataset, return_dict=True)
print(gbt_evaluation)
1/1 [==============================] - 0s 295ms/step - loss: 0.0000e+00 - accuracy: 0.8295 {'loss': 0.0, 'accuracy': 0.8295454382896423}
Gradient Boosted Trees is just one of the many decision forest algorithms available in TensorFlow Decision Forests. For example, Random Forests (available as tfdf.keras.RandomForestModel) are very resistant to overfitting, while CART (available as tfdf.keras.CartModel) is great for model interpretation.
In the next example, train and evaluate a Random Forest model:
# Train a Random Forest model
rf_model = tfdf.keras.RandomForestModel()
rf_model.fit(train_dataset)
# Evaluate the Random Forest model
rf_model.compile(metrics=['accuracy'])
rf_evaluation = rf_model.evaluate(eval_dataset, return_dict=True)
print(rf_evaluation)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. Use /tmpfs/tmp/tmptazbt_4s as temporary training directory Reading training dataset... Training dataset read in 0:00:00.188342. Found 627 examples. Training model... Model trained in 0:00:00.182442 Compiling model... [INFO 24-01-11 18:23:42.5392 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmptazbt_4s/model/ with prefix 34932ee267074ee5 [INFO 24-01-11 18:23:42.6346 UTC decision_forest.cc:660] Model loaded with 300 root(s), 34556 node(s), and 9 input feature(s). [INFO 24-01-11 18:23:42.6347 UTC kernel.cc:1061] Use fast generic engine Model compiled. 1/1 [==============================] - 0s 136ms/step - loss: 0.0000e+00 - accuracy: 0.8333 {'loss': 0.0, 'accuracy': 0.8333333134651184}
In the final example, train and plot a CART model:
# Train a CART model
cart_model = tfdf.keras.CartModel()
cart_model.fit(train_dataset)
# Plot the CART model
tfdf.model_plotter.plot_model_in_colab(cart_model, max_depth=2)
Warning: The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. WARNING:absl:The `num_threads` constructor argument is not set and the number of CPU is os.cpu_count()=32 > 32. Setting num_threads to 32. Set num_threads manually to use more than 32 cpus. Use /tmpfs/tmp/tmp1ugurarv as temporary training directory Reading training dataset... Training dataset read in 0:00:00.190634. Found 627 examples. Training model... Model trained in 0:00:00.016511 Compiling model... Model compiled. [INFO 24-01-11 18:23:43.1704 UTC kernel.cc:1233] Loading model from path /tmpfs/tmp/tmp1ugurarv/model/ with prefix e4c1cc3f3d714ea2 [INFO 24-01-11 18:23:43.1707 UTC decision_forest.cc:660] Model loaded with 1 root(s), 21 node(s), and 5 input feature(s). [INFO 24-01-11 18:23:43.1707 UTC kernel.cc:1061] Use fast generic engine