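Warning: Estimators are not recommended for new code. Estimators run v1.Session-style code, which is more difficult to write correctly and can behave unexpectedly, especially when combined with TF 2 code. Estimators do fall under our compatibility guarantees, but they will receive no fixes other than for security vulnerabilities. See the migration guide for details.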
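This tutorial shows you how to solve the Iris classification problem in TensorFlow using Estimators. An Estimator is a legacy TensorFlow high-level representation of a complete model. For more details, see Estimators.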
Note: In TensorFlow 2.0, the Keras API can accomplish these same tasks and is considered an easier API to learn. If you are just starting out, we recommend that you start with Keras.
First things first
To get started, first import TensorFlow and the other libraries you will need.
import tensorflow as tf
import pandas as pd
import numpy as np  # used by input_evaluation_set below
The dataset
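The sample program in this document builds and tests a model that classifies Iris flowers into three different species based on the size of their sepals and petals.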
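You will train a model using the Iris dataset. The dataset contains four features and one label. The four features identify the following botanical characteristics of individual Iris flowers:
- sepal length
- sepal width
- petal length
- petal width
Based on this information, you can define a few helpful constants for parsing the data: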
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']
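The integer labels in the CSV files are indices into SPECIES (0 is Setosa, 1 is Versicolor, 2 is Virginica). The prediction code at the end of this tutorial uses the same convention when it turns class_ids back into names. The sketch below spells out that mapping in both directions. FEATURE_NAMES, LABEL_NAME, id_to_species and species_to_id are hypothetical helper names introduced only for illustration.

```python
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

# Hypothetical helpers: the first four columns are features, the last is the label.
FEATURE_NAMES = CSV_COLUMN_NAMES[:-1]
LABEL_NAME = CSV_COLUMN_NAMES[-1]

# Integer label -> species name, and species name -> integer label.
id_to_species = dict(enumerate(SPECIES))
species_to_id = {name: i for i, name in enumerate(SPECIES)}

print(FEATURE_NAMES)             # ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']
print(id_to_species[2])          # Virginica
print(species_to_id['Setosa'])   # 0
```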
Next, download and parse the Iris dataset using Keras and pandas. Note that you keep separate datasets for training and testing.
train_path = tf.keras.utils.get_file(
    "iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
    "iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")

train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
If you inspect the data, you can see that it has four float feature columns and one integer label column.
train.head()
For each dataset, split off the labels, which the model will be trained to predict.
train_y = train.pop('Species')
test_y = test.pop('Species')
# The label column has now been removed from the features.
train.head()
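The following small sketch on an in-memory CSV shows what header=0 combined with names does, and what pop() leaves behind. It assumes that the downloaded files begin with a summary line (row count, feature count and class names) rather than with column names. That is why the tutorial discards the first row and supplies CSV_COLUMN_NAMES itself. The two data rows reuse the measurements from the input_evaluation_set example in the next section.

```python
import io

import pandas as pd

CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']

# Illustrative stand-in for iris_training.csv: an assumed summary first line,
# followed by two rows borrowed from input_evaluation_set.
csv_text = """120,4,setosa,versicolor,virginica
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
"""

# header=0 consumes the first line; names= supplies the real column names.
df = pd.read_csv(io.StringIO(csv_text), names=CSV_COLUMN_NAMES, header=0)

# pop() removes the label column from df in place and returns it as a Series.
labels = df.pop('Species')

print(list(df.columns))  # ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']
print(labels.tolist())   # [2, 1]
print(len(df))           # 2
```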
Overview of programming with Estimators
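Now that you have the data set up, you can define a model using a TensorFlow Estimator. An Estimator is any class derived from tf.estimator.Estimator. TensorFlow provides a collection of pre-made Estimators in tf.estimator (for example, LinearRegressor) that implement common ML algorithms. Beyond those, you can write your own custom Estimators. We recommend using pre-made Estimators when you are just getting started.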
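To write a TensorFlow program based on pre-made Estimators, you must perform the following tasks:
- Create one or more input functions.
- Define the model's feature columns.
- Instantiate an Estimator, specifying the feature columns and various hyperparameters.
- Call one or more methods on the Estimator object, passing the appropriate input function as the source of the data.
Let's see how those tasks are implemented for Iris classification.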
Create input functions
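You must create input functions to supply data for training, evaluating, and prediction.
An input function is a function that returns a tf.data.Dataset object, which outputs the following two-element tuple:
- features: a Python dictionary in which each key is the name of a feature and each value is an array containing all of that feature's values.
- label: an array containing the label value for every example.
To show the format of that tuple, here is a simple implementation that returns it directly instead of wrapping it in a Dataset: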
def input_evaluation_set():
    features = {'SepalLength': np.array([6.4, 5.0]),
                'SepalWidth': np.array([2.8, 2.3]),
                'PetalLength': np.array([5.6, 3.3]),
                'PetalWidth': np.array([2.2, 1.0])}
    labels = np.array([2, 1])
    return features, labels
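Your input function can generate the features dictionary and the label list any way you like. However, we recommend using TensorFlow's Dataset API, which can parse all sorts of data.
The Dataset API can handle many common cases for you. For example, with the Dataset API you can easily read records from a large collection of files in parallel and join them into a single stream.
To keep this example simple, you will load the data with pandas and build an input pipeline from this in-memory data: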
def input_fn(features, labels, training=True, batch_size=256):
    """An input function for training or evaluating"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle and repeat if you are in training mode.
    if training:
        dataset = dataset.shuffle(1000).repeat()

    return dataset.batch(batch_size)
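Conceptually, input_fn does three things. It turns column-oriented data into a stream of examples, shuffles and repeats that stream only in training mode, and groups it into batches. The pure-Python sketch below mirrors those semantics. toy_input_fn is a hypothetical name, and this is not how tf.data is implemented. Evaluation data is finite and may end with a partial batch, while training data never runs out. The sketch replaces shuffle(1000) with a full shuffle of each epoch. That produces the same behavior whenever the whole dataset fits in the shuffle buffer, as an Iris training set of roughly 120 rows (an assumption about iris_training.csv) does.

```python
import random
from itertools import islice

def toy_input_fn(features, labels, training=True, batch_size=256, seed=0):
    """Conceptual pure-Python analogue of input_fn; NOT how tf.data works internally.

    features: dict mapping feature name -> list of values (one per example)
    labels:   list of integer labels
    Yields batches as lists of (feature_dict, label) pairs, whereas tf.data's
    batch() would stack them into a dict of tensors plus a label tensor.
    """
    # from_tensor_slices: split column-oriented data into one example per row.
    examples = [({k: v[i] for k, v in features.items()}, label)
                for i, label in enumerate(labels)]
    rng = random.Random(seed)

    def stream():
        if not training:
            yield from examples          # evaluation: a single ordered pass
            return
        while True:                      # training: repeat() without end
            epoch = examples[:]
            rng.shuffle(epoch)           # simplification of shuffle(1000)
            yield from epoch

    it = stream()
    while True:
        batch = list(islice(it, batch_size))  # batch(): group consecutive examples
        if not batch:
            return
        yield batch

toy_features = {'SepalLength': [6.4, 5.0, 5.9], 'PetalWidth': [2.2, 1.0, 1.5]}
toy_labels = [2, 1, 1]

eval_batches = list(toy_input_fn(toy_features, toy_labels, training=False, batch_size=2))
print([len(b) for b in eval_batches])  # [2, 1]: finite, last batch is partial

train_batches = toy_input_fn(toy_features, toy_labels, training=True, batch_size=2)
print([len(next(train_batches)) for _ in range(4)])  # [2, 2, 2, 2]: never runs out
```

Because batches in training mode are drawn from an endless stream, they also cross epoch boundaries, which is why the number of training steps (not the dataset size) has to bound training.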
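Define the feature columns
A feature column is an object that describes how the model should use raw input data from the features dictionary. When you build an Estimator model, you pass it a list of feature columns that describes each of the features you want the model to use. The tf.feature_column module provides many options for representing data to the model.
For Iris, the four raw features are numeric values, so you'll build a list of feature columns that tells the Estimator model to represent each of the four features as a 32-bit floating-point value. The code to create the feature columns is therefore: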
# Feature columns describe how to use the input.
my_feature_columns = []
for key in train.keys():
    my_feature_columns.append(tf.feature_column.numeric_column(key=key))
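Feature columns can be far more sophisticated than those shown here. You can read more about feature columns in the feature columns guide.
Now that you have described how the model should represent the raw features, you can build the Estimator.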
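By default, numeric_column stores each feature as tf.float32, which is what "32-bit floating-point value" means above. The sketch below imitates that conversion with Python's array module, whose 'f' type code is a C float (32 bits on common platforms). dense_input and FEATURE_KEYS are hypothetical names. The column order the sketch uses is an assumption and does not describe how the Estimator's input layer actually orders columns.

```python
from array import array

# Hypothetical column order; the real input layer's ordering is not assumed here.
FEATURE_KEYS = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']

def dense_input(features, keys=FEATURE_KEYS):
    """Build one row of 32-bit floats per example from a dict of feature columns."""
    n_examples = len(features[keys[0]])
    return [array('f', [features[k][i] for k in keys]) for i in range(n_examples)]

rows = dense_input({'SepalLength': [6.4, 5.0],
                    'SepalWidth': [2.8, 2.3],
                    'PetalLength': [5.6, 3.3],
                    'PetalWidth': [2.2, 1.0]})
print(len(rows), len(rows[0]), rows[0].itemsize)  # 2 4 4
print(rows[0][0] == 6.4)  # False: 6.4 is rounded to the nearest float32
```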
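Instantiate an Estimator
The Iris problem is a classic classification problem. Fortunately, TensorFlow provides several pre-made classifier Estimators, including:
- tf.estimator.DNNClassifier for deep models that perform multi-class classification.
- tf.estimator.DNNLinearCombinedClassifier for wide & deep models.
- tf.estimator.LinearClassifier for classifiers based on linear models.
For the Iris problem, tf.estimator.DNNClassifier seems like the best choice. Here's how you can instantiate this Estimator: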
# Build a DNN with 2 hidden layers with 30 and 10 hidden nodes each.
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    # Two hidden layers of 30 and 10 nodes respectively.
    hidden_units=[30, 10],
    # The model must choose between 3 classes.
    n_classes=3)
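As a quick sanity check on the size of this network, you can count its trainable parameters by hand. The sketch assumes that each of the four numeric columns contributes one input value and that every layer is fully connected with a bias term. dense_param_count is a hypothetical helper.

```python
def dense_param_count(layer_sizes):
    """Count weights plus biases in a stack of fully connected layers."""
    return sum(n_in * n_out + n_out
               for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))

# 4 numeric inputs -> hidden_units=[30, 10] -> 3 output logits (n_classes=3)
sizes = [4] + [30, 10] + [3]
total = dense_param_count(sizes)
print(total)  # 493 = (4*30 + 30) + (30*10 + 10) + (10*3 + 3)
```

A few hundred parameters is tiny by deep-learning standards, which is one reason this example trains in seconds on a CPU.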
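Train, evaluate, and predict
Now that you have an Estimator object, you can call methods to do the following:
- Train the model.
- Evaluate the trained model.
- Use the trained model to make predictions.
Train the model
Train the model by calling the Estimator's train method as follows: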
# Train the Model.
classifier.train(
    input_fn=lambda: input_fn(train, train_y, training=True),
    steps=5000)
INFO:tensorflow:loss = 1.5073682, step = 0
INFO:tensorflow:loss = 0.7207211, step = 1000 (0.164 sec)
INFO:tensorflow:loss = 0.6554887, step = 2000 (0.170 sec)
INFO:tensorflow:loss = 0.5934744, step = 3000 (0.163 sec)
INFO:tensorflow:loss = 0.57464546, step = 4000 (0.165 sec)
INFO:tensorflow:Loss for final step: 0.512492.
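Note that you wrap the input_fn call in a lambda to capture the arguments while still providing an input function that takes no arguments, as the Estimator expects. The steps argument tells the method to stop training after that many training steps. Because the training input function repeats the data indefinitely, steps is what bounds training.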
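A lambda is just one way to build the zero-argument callable that train expects, and functools.partial works equally well. The sketch below uses fake_input_fn, a hypothetical stand-in with the tutorial's signature, so that it runs without TensorFlow and does not overwrite the real input_fn. It also works out how much data steps=5000 consumes, assuming the training CSV has 120 rows.

```python
import functools

def fake_input_fn(features, labels, training=True, batch_size=256):
    """Hypothetical stand-in with input_fn's signature; returns a summary tuple
    instead of a tf.data.Dataset so this sketch runs without TensorFlow."""
    return (len(labels), training, batch_size)

# Placeholder data with an assumed 120 training rows.
toy_x = {'SepalLength': [5.0] * 120}
toy_y = [0] * 120

# Both objects are zero-argument callables that capture the arguments for later.
via_lambda = lambda: fake_input_fn(toy_x, toy_y, training=True)
via_partial = functools.partial(fake_input_fn, toy_x, toy_y, training=True)
print(via_lambda() == via_partial())  # True

# Data consumed by steps=5000 at the default batch_size of 256.
steps, batch_size, n_train = 5000, 256, len(toy_y)
examples_seen = steps * batch_size
print(examples_seen, round(examples_seen / n_train, 1))  # 1280000 10666.7
```

Under that assumption, the model sees each training example roughly ten thousand times during the 5,000 steps.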
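Evaluate the trained model
Now that the model has been trained, you can get some statistics on its performance. The following code block evaluates the accuracy of the trained model on the test data: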
eval_result = classifier.evaluate(
    input_fn=lambda: input_fn(test, test_y, training=False))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.53333336, average_loss = 0.6654332, global_step = 5000, loss = 0.6654332
Test set accuracy: 0.533
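Unlike the call to the train method, you did not pass the steps argument to evaluate. With training=False, input_fn skips shuffle() and repeat(), so the evaluation input yields only a single epoch of data and evaluation stops on its own.
The eval_result dictionary also contains average_loss (the mean loss per example), loss (the mean loss per mini-batch), and the Estimator's global_step (the number of training iterations it underwent).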
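The sketch below uses made-up class probabilities to show how the reported metrics relate. average_loss averages the cross-entropy over examples, while loss averages the per-batch means, so the two differ only when there are several batches of unequal size. The sketch follows the description above, not the Estimator's actual metric code. Assume the Iris test CSV has 30 rows, which fits in one batch of batch_size=256. Then loss and average_loss coincide, as they do in the evaluation output, and an accuracy of 0.53333336 corresponds to 16 of 30 correct predictions.

```python
import math

def cross_entropy(probs, label):
    """Negative log of the probability assigned to the true class."""
    return -math.log(probs[label])

def summarize(batches):
    """batches: list of batches, each a list of (class_probabilities, true_label)."""
    examples = [ex for batch in batches for ex in batch]
    per_example = [cross_entropy(p, y) for p, y in examples]
    per_batch = [sum(cross_entropy(p, y) for p, y in b) / len(b) for b in batches]
    correct = sum(max(range(len(p)), key=p.__getitem__) == y for p, y in examples)
    return {
        'accuracy': correct / len(examples),
        'average_loss': sum(per_example) / len(per_example),  # mean over examples
        'loss': sum(per_batch) / len(per_batch),              # mean of batch means
    }

# Made-up probabilities for three examples, split into batches of 2 and 1.
batches = [
    [([0.7, 0.2, 0.1], 0), ([0.1, 0.6, 0.3], 2)],
    [([0.2, 0.3, 0.5], 2)],
]
two_batches = summarize(batches)
one_batch = summarize([batches[0] + batches[1]])

print(round(two_batches['accuracy'], 3))                           # 0.667
print(two_batches['loss'] == two_batches['average_loss'])          # False
print(math.isclose(one_batch['loss'], one_batch['average_loss']))  # True
print(round(16 / 30, 3))                                           # 0.533
```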
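Making predictions (inferring) from the trained model
You now have a trained model and its evaluation results. You can use the trained model to predict the species of an Iris flower from some unlabeled measurements. As with training and evaluation, you make predictions with a single function call: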
# Generate predictions from the model
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7, 4.2, 5.4],
    'PetalWidth': [0.5, 1.5, 2.1],
}

# Note: this redefines input_fn for prediction; it takes no labels.
def input_fn(features, batch_size=256):
    """An input function for prediction."""
    # Convert the inputs to a Dataset without labels.
    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)

predictions = classifier.predict(
    input_fn=lambda: input_fn(predict_x))
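The predict method returns a Python iterable that yields a dictionary of prediction results for each example. The following code prints a few predictions and their probabilities: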
for pred_dict, expec in zip(predictions, expected):
    class_id = pred_dict['class_ids'][0]
    probability = pred_dict['probabilities'][class_id]

    print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
        SPECIES[class_id], 100 * probability, expec))
Prediction is "Setosa" (78.5%), expected "Setosa"
Prediction is "Virginica" (40.7%), expected "Versicolor"
Prediction is "Virginica" (75.4%), expected "Virginica"
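Each probability in the predictions comes from a softmax over the model's logits, and class_ids holds the index of the largest one. A low winning probability, such as the 40.7% for the misclassified second example, signals an uncertain prediction. The sketch below reproduces that arithmetic for made-up logits. to_pred_dict is a hypothetical helper that imitates only the 'logits', 'probabilities' and 'class_ids' keys of the real prediction dictionaries.

```python
import math

SPECIES = ['Setosa', 'Versicolor', 'Virginica']

def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def to_pred_dict(logits):
    """Hypothetical helper mimicking three keys of a predict() result dictionary."""
    probs = softmax(logits)
    class_id = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return {'logits': logits, 'probabilities': probs, 'class_ids': [class_id]}

# Made-up logits for illustration; they are not produced by the classifier above.
pred_dict = to_pred_dict([1.0, 0.0, -1.0])
class_id = pred_dict['class_ids'][0]
probability = pred_dict['probabilities'][class_id]
print('Prediction is "{}" ({:.1f}%)'.format(SPECIES[class_id], 100 * probability))
# Prediction is "Setosa" (66.5%)
```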