TensorFlow.org で表示 | Google Colab で実行 | GitHub でソースを表示 | ノートブックをダウンロード |
警告: 新しいコードには Estimators は推奨されません。Estimators は
v1.Session
スタイルのコードを実行しますが、これは正しく記述するのはより難しく、特に TF 2 コードと組み合わせると予期しない動作をする可能性があります。Estimators は、互換性保証の対象となりますが、セキュリティの脆弱性以外の修正は行われません。詳細については、移行ガイドを参照してください。
このチュートリアルでは、Estimator を使用して、TensorFlow でアヤメの分類問題を解決する方法を示します。Estimator は、レガシー TensorFlow における完全なモデルの高レベルの表現です。詳細については、Estimatorをご覧ください。
注意: TensorFlow 2.0 では、Keras API でも同じタスクを実行でき、より学習しやすい API とされています。はじめて学習する場合は、Keras から着手することをお勧めします。
まず最初に
始めるには、最初に TensorFlow と必要となる多数のライブラリをインポートします。
import tensorflow as tf
import pandas as pd
2022-12-15 02:40:34.596324: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-15 02:40:34.596441: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-15 02:40:34.596452: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
データセット
このドキュメントのサンプルプログラムは、アヤメの花を、萼片と花弁のサイズに基づいて、3 つの品種に分類するモデルを構築してテストします。
モデルのトレーニングには、Iris データセットを使用します。Iris データセットには 4 つの特徴量と 1 つのラベルが含まれます。4 つの特徴量は、次に示す各アヤメの植物学的特性を識別します。
- 萼片の長さ
- 萼片の幅
- 花弁の長さ
- 花弁の幅
この情報に基づき、データを解析する上で役立ついくつかの定数を定義できます。
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']
次に、Keras と Pandas を使用して、Iris データセットをダウンロードして解析します。トレーニング用とテスト用に別々のデータセットを維持することに注意してください。
train_path = tf.keras.utils.get_file(
"iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
"iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")
train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv 2194/2194 [==============================] - 0s 0us/step Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv 573/573 [==============================] - 0s 0us/step
データを検査し、4 つの浮動小数型の特徴量カラムと 1 つの int32 ラベルがあることを確認します。
train.head()
各データセットに対し、モデルが予測するようにトレーニングされるラベルを分割します。
train_y = train.pop('Species')
test_y = test.pop('Species')
# The label column has now been removed from the features.
train.head()
Estimator を使ったプログラミングの概要
データのセットアップが完了したので、TensorFlow Estimator を使ってモデルを定義できます。Estimator は、tf.estimator.Estimator
から派生したクラスです。TensorFlow は、一群の tf.estimator
(LinearRegressor
など)を提供しており、一般的な ML アルゴリズムを実装することができます。このほか、独自のカスタム Estimator を作成することもできますが、使用し始めには、事前作成済みの Estimator を使用することをお勧めします。
事前作成済みの Estimator に基づいて TensorFlow プログラムを記述するには、次のタスクを実行する必要があります。
- 1 つ以上の入力関数を作成する。
- モデルの特徴量カラムを定義する。
- Estimator をインスタンス化する。特徴量カラムとさまざまなハイパーパラメータを指定します。
- Estimator オブジェクトに 1 つ以上のメソッドを呼び出す。データのソースとして適切な入力関数を渡します。
では、アヤメの分類において、これらのタスクをどのように実装するのか見てみましょう。
入力関数を作成する
トレーニング、評価、および予測を行うためのデータを提供する入力関数を作成する必要があります。
入力関数とは、次の要素タプルを出力する tf.data.Dataset
オブジェクトを返す関数です。
features
- 次のような Python ディクショナリ。- 各キーが特徴量の名前である。
- 各値が、特徴量の値のすべてを含む配列である。
label
- 各サンプルの label の値を含む配列。
入力関数の書式を示すために、単純な実装を次に示します。
def input_evaluation_set():
features = {'SepalLength': np.array([6.4, 5.0]),
'SepalWidth': np.array([2.8, 2.3]),
'PetalLength': np.array([5.6, 3.3]),
'PetalWidth': np.array([2.2, 1.0])}
labels = np.array([2, 1])
return features, labels
入力関数を自分で作成すれば、features
ディクショナリと label
リストを好みに合わせて生成できるようにすることができますが、あらゆる種類のデータを解析できる TensorFlow の Dataset API を使用することをお勧めします。
Dataset API は、多数の一般的な事例を処理することができます。たとえば、Dataset API を使用すると、大量のファイルのレコードを並列して読み取り、単一のストリームに結合することが簡単に行えます。
この例では事を単純にするために、pandas でデータを読み込み、このメモリ内のデータから入力パイプラインを構築します。
def input_fn(features, labels, training=True, batch_size=256):
"""An input function for training or evaluating"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle and repeat if you are in training mode.
if training:
dataset = dataset.shuffle(1000).repeat()
return dataset.batch(batch_size)
特徴量カラムを定義する
特徴量カラムは、特徴量ディクショナリの生の入力データを、モデルがどのように使用すべきかを説明するオブジェクトです。Estimator モデルを作成する際に、モデルが使用する各特徴量を説明する特徴量カラムをモデルに渡します。tf.feature_column
モジュールには、モデルに対してデータを表現するためのオプションが多数含まれています。
Iris については、4 つの生の特徴量は数値であるため、Estimator に対して、これら 4 つの各特徴量を 32 ビットの浮動小数点数型の値として表現するように命令する特徴量カラムを構築します。したがって、特徴カラムを作成するためのコードは、次のようになります。
# Feature columns describe how to use the input.
my_feature_columns = []
for key in train.keys():
my_feature_columns.append(tf.feature_column.numeric_column(key=key))
特徴量カラムは、ここに示すものよりもはるかに高度なものに構築することができます。特徴量カラムの詳細については、こちらのガイドをご覧ください。
モデルが生の特徴量をどのように表現するかに関する記述を準備できたので、Estimator を構築することができます。
Estimator をインスタンス化する
アヤメの問題はよく知られた分類問題です。幸いにも、TensorFlow は、次のような事前作成済みの分類子 Estimator を複数用意しています。
tf.estimator.DNNClassifier
: 多クラス分類を実行するディープモデルに使用。tf.estimator.DNNLinearCombinedClassifier
: ワイド&ディープモデルに使用。tf.estimator.LinearClassifier
: 線形モデルに基づく分類子に使用。
アヤメの問題に関しては、tf.estimator.DNNClassifier
が最適な選択肢と言えます。この Estimator をインスタンス化する方法を次に示します。
# Build a DNN with 2 hidden layers with 30 and 10 hidden nodes each.
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
# Two hidden layers of 30 and 10 nodes respectively.
hidden_units=[30, 10],
# The model must choose between 3 classes.
n_classes=3)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpv8gz9mg7 INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpv8gz9mg7', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
トレーニングして評価して予測する
Estimator オブジェクトを準備したので、次の項目を行うメソッド呼び出すことができます。
- モデルをトレーニングする。
- トレーニングされたモデルを評価する。
- トレーニングされたモデルを使用して、予測を立てる。
モデルをトレーニングする
次のように、Estimator の train
メソッドを呼び出して、モデルをトレーニングします。
# Train the Model.
classifier.train(
input_fn=lambda: input_fn(train, train_y, training=True),
steps=5000)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/training_util.py:396: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/optimizers/optimizer_v2/adagrad.py:93: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 2022-12-15 02:40:40.279374: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT64 } } } is neither a subtype nor a supertype of the combined inputs preceding it: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT32 } } } while inferring type of node 'dnn/zero_fraction/cond/output/_18' INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpv8gz9mg7/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 1.3653078, step = 0 INFO:tensorflow:global_step/sec: 424.758 INFO:tensorflow:loss = 1.0535867, step = 100 (0.237 sec) INFO:tensorflow:global_step/sec: 551.284 INFO:tensorflow:loss = 0.9845939, step = 200 (0.182 sec) INFO:tensorflow:global_step/sec: 564.881 INFO:tensorflow:loss = 0.93356884, step = 300 (0.177 sec) INFO:tensorflow:global_step/sec: 566.209 INFO:tensorflow:loss = 0.8931673, step = 400 (0.177 sec) INFO:tensorflow:global_step/sec: 566.856 INFO:tensorflow:loss = 0.863848, step = 500 (0.176 sec) INFO:tensorflow:global_step/sec: 545.981 INFO:tensorflow:loss = 0.83251595, step = 600 (0.183 sec) INFO:tensorflow:global_step/sec: 556.842 INFO:tensorflow:loss = 0.79153323, step = 700 (0.180 sec) INFO:tensorflow:global_step/sec: 555.019 INFO:tensorflow:loss = 0.7832278, step = 800 (0.180 sec) INFO:tensorflow:global_step/sec: 545.374 INFO:tensorflow:loss = 0.7552776, step = 900 (0.183 sec) INFO:tensorflow:global_step/sec: 532.68 INFO:tensorflow:loss = 0.7393273, step = 1000 (0.188 sec) INFO:tensorflow:global_step/sec: 548.912 INFO:tensorflow:loss = 0.72828543, step = 1100 (0.182 sec) INFO:tensorflow:global_step/sec: 540.638 INFO:tensorflow:loss = 0.7053199, step = 1200 (0.185 sec) INFO:tensorflow:global_step/sec: 538.089 INFO:tensorflow:loss = 0.6969624, step = 1300 (0.186 sec) INFO:tensorflow:global_step/sec: 543.99 INFO:tensorflow:loss = 0.6601064, step = 1400 (0.184 sec) INFO:tensorflow:global_step/sec: 545.081 INFO:tensorflow:loss = 0.6538592, step = 1500 (0.183 sec) INFO:tensorflow:global_step/sec: 583.538 INFO:tensorflow:loss = 0.6431242, step = 1600 (0.172 sec) INFO:tensorflow:global_step/sec: 582.859 INFO:tensorflow:loss = 0.6337292, step = 1700 (0.171 sec) INFO:tensorflow:global_step/sec: 581.883 INFO:tensorflow:loss = 0.6249995, step = 1800 (0.172 sec) INFO:tensorflow:global_step/sec: 599.331 INFO:tensorflow:loss = 0.6074962, step = 1900 (0.167 sec) INFO:tensorflow:global_step/sec: 585.038 INFO:tensorflow:loss = 0.5954494, step = 2000 (0.171 sec) INFO:tensorflow:global_step/sec: 582.089 INFO:tensorflow:loss = 0.59253395, step = 2100 (0.172 sec) INFO:tensorflow:global_step/sec: 602.62 INFO:tensorflow:loss = 0.5689405, step = 2200 (0.166 sec) INFO:tensorflow:global_step/sec: 588.479 INFO:tensorflow:loss = 0.5602833, step = 2300 (0.170 sec) INFO:tensorflow:global_step/sec: 588.211 INFO:tensorflow:loss = 0.55631864, step = 2400 (0.170 sec) INFO:tensorflow:global_step/sec: 592.785 INFO:tensorflow:loss = 0.54900175, step = 2500 (0.169 sec) INFO:tensorflow:global_step/sec: 580.758 INFO:tensorflow:loss = 0.54516006, step = 2600 (0.172 sec) INFO:tensorflow:global_step/sec: 578.628 INFO:tensorflow:loss = 0.52530795, step = 2700 (0.173 sec) INFO:tensorflow:global_step/sec: 587.328 INFO:tensorflow:loss = 0.5299491, step = 2800 (0.170 sec) INFO:tensorflow:global_step/sec: 579.237 INFO:tensorflow:loss = 0.5184585, step = 2900 (0.173 sec) INFO:tensorflow:global_step/sec: 579.646 INFO:tensorflow:loss = 0.5071718, step = 3000 (0.173 sec) INFO:tensorflow:global_step/sec: 588.019 INFO:tensorflow:loss = 0.4960404, step = 3100 (0.170 sec) INFO:tensorflow:global_step/sec: 577.387 INFO:tensorflow:loss = 0.47985545, step = 3200 (0.173 sec) INFO:tensorflow:global_step/sec: 595.986 INFO:tensorflow:loss = 0.48654804, step = 3300 (0.168 sec) INFO:tensorflow:global_step/sec: 573.446 INFO:tensorflow:loss = 0.48582077, step = 3400 (0.174 sec) INFO:tensorflow:global_step/sec: 578.275 INFO:tensorflow:loss = 0.46541944, step = 3500 (0.173 sec) INFO:tensorflow:global_step/sec: 566.381 INFO:tensorflow:loss = 0.4748811, step = 3600 (0.177 sec) INFO:tensorflow:global_step/sec: 571.906 INFO:tensorflow:loss = 0.47083074, step = 3700 (0.175 sec) INFO:tensorflow:global_step/sec: 574.657 INFO:tensorflow:loss = 0.44040596, step = 3800 (0.174 sec) INFO:tensorflow:global_step/sec: 574.819 INFO:tensorflow:loss = 0.45463592, step = 3900 (0.174 sec) INFO:tensorflow:global_step/sec: 588.25 INFO:tensorflow:loss = 0.44358343, step = 4000 (0.170 sec) INFO:tensorflow:global_step/sec: 585.435 INFO:tensorflow:loss = 0.4394082, step = 4100 (0.171 sec) INFO:tensorflow:global_step/sec: 570.47 INFO:tensorflow:loss = 0.44057947, step = 4200 (0.175 sec) INFO:tensorflow:global_step/sec: 570.751 INFO:tensorflow:loss = 0.43266475, step = 4300 (0.175 sec) INFO:tensorflow:global_step/sec: 573.692 INFO:tensorflow:loss = 0.4180813, step = 4400 (0.174 sec) INFO:tensorflow:global_step/sec: 557.131 INFO:tensorflow:loss = 0.42264342, step = 4500 (0.180 sec) INFO:tensorflow:global_step/sec: 573.329 INFO:tensorflow:loss = 0.42919323, step = 4600 (0.174 sec) INFO:tensorflow:global_step/sec: 588.864 INFO:tensorflow:loss = 0.41696268, step = 4700 (0.170 sec) INFO:tensorflow:global_step/sec: 579.002 INFO:tensorflow:loss = 0.40943825, step = 4800 (0.173 sec) INFO:tensorflow:global_step/sec: 566.232 INFO:tensorflow:loss = 0.410751, step = 4900 (0.177 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000... INFO:tensorflow:Saving checkpoints for 5000 into /tmpfs/tmp/tmpv8gz9mg7/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000... INFO:tensorflow:Loss for final step: 0.4027862. <tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x7f003e6a3dc0>
Estimator が期待するとおり、引数を取らない入力関数を指定しながら、input_fn
呼び出しを lambda
にラッピングして引数をキャプチャするところに注意してください。steps
引数はメソッドに対して、あるトレーニングステップ数を完了した後にトレーニングを停止するように指定しています。
トレーニングされたモデルを評価する
モデルのトレーニングが完了したので、そのパフォーマンスに関する統計を得ることができます。次のコードブロックは、テストデータに対してトレーニングされたモデルの精度を評価します。
eval_result = classifier.evaluate(
input_fn=lambda: input_fn(test, test_y, training=False))
print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-12-15T02:40:50 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpv8gz9mg7/model.ckpt-5000 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Inference Time : 0.52079s INFO:tensorflow:Finished evaluation at 2022-12-15-02:40:50 INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.9, average_loss = 0.4802318, global_step = 5000, loss = 0.4802318 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmpfs/tmp/tmpv8gz9mg7/model.ckpt-5000 Test set accuracy: 0.900
train
メソッドへの呼び出しとは異なり、評価するsteps
引数を渡していません。eval の input_fn
データの単一のエポックのみを返します。
eval_result
ディクショナリには、average_loss
(サンプル当たりの平均損失)、loss
(ミニバッチ当たりの平均損失)、および Estimator の global_step
の値(実行したトレーニングイテレーションの回数)も含まれます。
トレーニングされたモデルから予測(推論)を立てる
良質の評価結果を生み出すトレーニング済みのモデルを準備できました。これから、このトレーニング済みのモデルを使用し、ラベル付けできない測定に基づいてアヤメの品種を予測します。トレーニングと評価と同様に、単一の関数呼び出して予測を行います。
# Generate predictions from the model
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
'SepalLength': [5.1, 5.9, 6.9],
'SepalWidth': [3.3, 3.0, 3.1],
'PetalLength': [1.7, 4.2, 5.4],
'PetalWidth': [0.5, 1.5, 2.1],
}
def input_fn(features, batch_size=256):
"""An input function for prediction."""
# Convert the inputs to a Dataset without labels.
return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)
predictions = classifier.predict(
input_fn=lambda: input_fn(predict_x))
predict
メソッドは Python イテラブルを返し、各サンプルの予測結果のディクショナリを生成します。次のコードを使って、予測とその確率を出力します。
for pred_dict, expec in zip(predictions, expected):
class_id = pred_dict['class_ids'][0]
probability = pred_dict['probabilities'][class_id]
print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
SPECIES[class_id], 100 * probability, expec))
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpv8gz9mg7/model.ckpt-5000 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. Prediction is "Setosa" (85.5%), expected "Setosa" Prediction is "Versicolor" (47.2%), expected "Versicolor" Prediction is "Virginica" (60.3%), expected "Virginica"