ดูบน TensorFlow.org | ทำงานใน Google Colab | ดูแหล่งที่มาบน GitHub | ดาวน์โหลดโน๊ตบุ๊ค |
บทช่วยสอนนี้แสดงวิธีแก้ปัญหาการจำแนก Iris ใน TensorFlow โดยใช้ Estimators Estimator คือการนำเสนอ TensorFlow ในระดับสูงแบบเดิมของแบบจำลองที่สมบูรณ์ สำหรับรายละเอียดเพิ่มเติม โปรดดูที่ เครื่องมือประมาณการ
สิ่งแรกก่อน
ในการเริ่มต้น คุณจะต้องนำเข้า TensorFlow และห้องสมุดจำนวนหนึ่งที่คุณต้องการ
import tensorflow as tf
import pandas as pd
ชุดข้อมูล
โปรแกรมตัวอย่างในเอกสารนี้สร้างและทดสอบแบบจำลองที่จำแนกดอกไอริสออกเป็นสามสายพันธุ์ตามขนาดของ กลีบเลี้ยง และ กลีบดอก
คุณจะฝึกโมเดลโดยใช้ชุดข้อมูล Iris ชุดข้อมูล Iris ประกอบด้วยสี่คุณสมบัติและหนึ่ง ป้ายกำกับ คุณสมบัติทั้งสี่ระบุลักษณะทางพฤกษศาสตร์ของดอกไอริสแต่ละดอกดังต่อไปนี้:
- ความยาวของกลีบเลี้ยง
- ความกว้างของกลีบเลี้ยง
- ความยาวของกลีบดอก
- ความกว้างกลีบ
จากข้อมูลนี้ คุณสามารถกำหนดค่าคงที่ที่เป็นประโยชน์บางประการสำหรับการแยกวิเคราะห์ข้อมูล:
CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']
ถัดไป ดาวน์โหลดและแยกวิเคราะห์ชุดข้อมูล Iris โดยใช้ Keras และ Pandas โปรดทราบว่าคุณเก็บชุดข้อมูลที่แตกต่างกันสำหรับการฝึกอบรมและการทดสอบ
train_path = tf.keras.utils.get_file(
"iris_training.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv")
test_path = tf.keras.utils.get_file(
"iris_test.csv", "https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv")
train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv 16384/2194 [================================================================================================================================================================================================================================] - 0s 0us/step Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv 16384/573 [=========================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================================] - 0s 0us/step
คุณสามารถตรวจสอบข้อมูลของคุณเพื่อดูว่าคุณมีคอลัมน์คุณลักษณะโฟลตสี่คอลัมน์และป้ายกำกับ int32 หนึ่งรายการ
train.head()
สำหรับแต่ละชุดข้อมูล ให้แยกป้ายกำกับ ซึ่งโมเดลจะได้รับการฝึกให้คาดการณ์
train_y = train.pop('Species')
test_y = test.pop('Species')
# The label column has now been removed from the features.
train.head()
ภาพรวมของการเขียนโปรแกรมด้วย Estimators
เมื่อคุณได้ตั้งค่าข้อมูลแล้ว คุณสามารถกำหนดแบบจำลองโดยใช้เครื่องมือประมาณการ TensorFlow Estimator คือคลาสใดๆ ที่ได้มาจาก tf.estimator.Estimator
TensorFlow จัดเตรียมคอลเลกชันของ tf.estimator
(เช่น LinearRegressor
) เพื่อนำอัลกอริธึม ML ทั่วไปไปใช้ นอกเหนือจากนั้น คุณอาจเขียน เครื่องมือประมาณการแบบกำหนด เองของคุณเอง ขอแนะนำให้ใช้ตัวประมาณการที่สร้างไว้ล่วงหน้าเมื่อเพิ่งเริ่มต้น
ในการเขียนโปรแกรม TensorFlow โดยยึดตามเครื่องมือประมาณการที่สร้างไว้ล่วงหน้า คุณต้องดำเนินการดังต่อไปนี้:
- สร้างฟังก์ชันอินพุตอย่างน้อยหนึ่งฟังก์ชัน
- กำหนดคอลัมน์คุณลักษณะของโมเดล
- สร้างอินสแตนซ์ของ Estimator โดยระบุคอลัมน์คุณลักษณะและไฮเปอร์พารามิเตอร์ต่างๆ
- เรียกใช้เมธอดอย่างน้อยหนึ่งวิธีบนออบเจ็กต์ Estimator โดยส่งฟังก์ชันอินพุตที่เหมาะสมเป็นแหล่งข้อมูล
เรามาดูกันว่างานเหล่านั้นถูกนำไปใช้สำหรับการจำแนก Iris อย่างไร
สร้างฟังก์ชันอินพุต
คุณต้องสร้างฟังก์ชันอินพุตเพื่อจัดหาข้อมูลสำหรับการฝึกอบรม การประเมิน และการทำนาย
ฟังก์ชันอินพุต คือฟังก์ชันที่ส่งคืนอ็อบเจ็กต์ tf.data.Dataset
ซึ่งส่งออกทูเพิลสององค์ประกอบต่อไปนี้:
-
features
- พจนานุกรม Python ที่:- แต่ละคีย์คือชื่อของฟีเจอร์
- แต่ละค่าคืออาร์เรย์ที่มีค่าทั้งหมดของคุณลักษณะนั้น
-
label
- อาร์เรย์ที่มีค่าของ ป้ายกำกับ สำหรับทุกตัวอย่าง
เพื่อแสดงรูปแบบของฟังก์ชันอินพุต ต่อไปนี้คือการใช้งานอย่างง่าย:
def input_evaluation_set():
features = {'SepalLength': np.array([6.4, 5.0]),
'SepalWidth': np.array([2.8, 2.3]),
'PetalLength': np.array([5.6, 3.3]),
'PetalWidth': np.array([2.2, 1.0])}
labels = np.array([2, 1])
return features, labels
ฟังก์ชันป้อนข้อมูลของคุณอาจสร้างพจนานุกรม features
และรายการ label
ตามที่คุณต้องการ อย่างไรก็ตาม ขอแนะนำให้ใช้ Dataset API ของ TensorFlow ซึ่งสามารถแยกวิเคราะห์ข้อมูลได้ทุกประเภท
Dataset API สามารถจัดการกรณีทั่วไปได้มากมายสำหรับคุณ ตัวอย่างเช่น เมื่อใช้ Dataset API คุณสามารถอ่านบันทึกจากคอลเล็กชันไฟล์ขนาดใหญ่แบบขนานและรวมเป็นสตรีมเดียวได้อย่างง่ายดาย
เพื่อให้ง่ายในตัวอย่างนี้ คุณจะต้องโหลดข้อมูลด้วย pandas และสร้างไพพ์ไลน์อินพุตจากข้อมูลในหน่วยความจำนี้:
def input_fn(features, labels, training=True, batch_size=256):
"""An input function for training or evaluating"""
# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
# Shuffle and repeat if you are in training mode.
if training:
dataset = dataset.shuffle(1000).repeat()
return dataset.batch(batch_size)
กำหนดคอลัมน์คุณลักษณะ
คอลัมน์คุณลักษณะ เป็นอ็อบเจ็กต์ที่อธิบายว่าโมเดลควรใช้ข้อมูลดิบอินพุตจากพจนานุกรมคุณลักษณะอย่างไร เมื่อคุณสร้างแบบจำลองประมาณการ คุณจะต้องส่งรายการคอลัมน์คุณลักษณะที่อธิบายคุณลักษณะแต่ละอย่างที่คุณต้องการให้แบบจำลองใช้ โมดูล tf.feature_column
มีตัวเลือกมากมายสำหรับการแสดงข้อมูลไปยังโมเดล
สำหรับ Iris คุณลักษณะดิบ 4 รายการเป็นค่าตัวเลข ดังนั้น คุณจะต้องสร้างรายการคอลัมน์คุณลักษณะเพื่อบอกให้โมเดล Estimator แสดงคุณลักษณะแต่ละอย่างของคุณลักษณะทั้งสี่เป็นค่าทศนิยม 32 บิต ดังนั้น รหัสสำหรับสร้างคอลัมน์คุณลักษณะคือ:
# Feature columns describe how to use the input.
my_feature_columns = []
for key in train.keys():
my_feature_columns.append(tf.feature_column.numeric_column(key=key))
คอลัมน์คุณลักษณะอาจซับซ้อนกว่าที่แสดงที่นี่มาก คุณสามารถอ่านเพิ่มเติมเกี่ยวกับคอลัมน์คุณลักษณะได้ใน คู่มือ นี้
ตอนนี้ คุณมีคำอธิบายว่าคุณต้องการให้โมเดลแสดงคุณลักษณะดิบอย่างไร คุณสามารถสร้างตัวประมาณได้
ยกตัวอย่างตัวประมาณ
ปัญหาม่านตาเป็นปัญหาการจำแนกประเภทคลาสสิก โชคดีที่ TensorFlow มีตัวประมาณการลักษณนามที่สร้างไว้ล่วงหน้าหลายตัว ซึ่งรวมถึง:
-
tf.estimator.DNNClassifier
สำหรับโมเดลเชิงลึกที่ทำการจำแนกประเภทหลายคลาส -
tf.estimator.DNNLinearCombinedClassifier
สำหรับรุ่นกว้างและลึก -
tf.estimator.LinearClassifier
สำหรับตัวแยกประเภทตามแบบจำลองเชิงเส้น
สำหรับปัญหาม่านตา tf.estimator.DNNClassifier
ดูเหมือนจะเป็นตัวเลือกที่ดีที่สุด นี่คือวิธีที่คุณยกตัวอย่างเครื่องมือประมาณนี้:
# Build a DNN with 2 hidden layers with 30 and 10 hidden nodes each.
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
# Two hidden layers of 30 and 10 nodes respectively.
hidden_units=[30, 10],
# The model must choose between 3 classes.
n_classes=3)
INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpxdgumb2t INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpxdgumb2t', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
ฝึกฝน ประเมิน และทำนาย
ในตอนนี้ คุณมีออบเจ็กต์ Estimator แล้ว คุณสามารถเรียกใช้เมธอดเพื่อทำสิ่งต่อไปนี้ได้:
- ฝึกโมเดล.
- ประเมินแบบจำลองการฝึกอบรม
- ใช้แบบจำลองที่ได้รับการฝึกอบรมมาในการทำนาย
ฝึกโมเดล
ฝึกโมเดลโดยเรียกวิธี train
ของ Estimator ดังนี้
# Train the Model.
classifier.train(
input_fn=lambda: input_fn(train, train_y, training=True),
steps=5000)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/adagrad.py:84: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpxdgumb2t/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 1.6787335, step = 0 INFO:tensorflow:global_step/sec: 305.625 INFO:tensorflow:loss = 1.1945828, step = 100 (0.328 sec) INFO:tensorflow:global_step/sec: 375.48 INFO:tensorflow:loss = 1.0221117, step = 200 (0.266 sec) INFO:tensorflow:global_step/sec: 376.21 INFO:tensorflow:loss = 0.9240805, step = 300 (0.266 sec) INFO:tensorflow:global_step/sec: 377.968 INFO:tensorflow:loss = 0.85917354, step = 400 (0.265 sec) INFO:tensorflow:global_step/sec: 376.297 INFO:tensorflow:loss = 0.81545967, step = 500 (0.265 sec) INFO:tensorflow:global_step/sec: 367.549 INFO:tensorflow:loss = 0.7771524, step = 600 (0.272 sec) INFO:tensorflow:global_step/sec: 378.887 INFO:tensorflow:loss = 0.74371505, step = 700 (0.264 sec) INFO:tensorflow:global_step/sec: 379.26 INFO:tensorflow:loss = 0.717993, step = 800 (0.264 sec) INFO:tensorflow:global_step/sec: 370.102 INFO:tensorflow:loss = 0.6952705, step = 900 (0.270 sec) INFO:tensorflow:global_step/sec: 373.034 INFO:tensorflow:loss = 0.68044865, step = 1000 (0.268 sec) INFO:tensorflow:global_step/sec: 372.193 INFO:tensorflow:loss = 0.65181077, step = 1100 (0.269 sec) INFO:tensorflow:global_step/sec: 339.238 INFO:tensorflow:loss = 0.6319051, step = 1200 (0.295 sec) INFO:tensorflow:global_step/sec: 334.252 INFO:tensorflow:loss = 0.63433766, step = 1300 (0.299 sec) INFO:tensorflow:global_step/sec: 343.436 INFO:tensorflow:loss = 0.61748827, step = 1400 (0.291 sec) INFO:tensorflow:global_step/sec: 346.575 INFO:tensorflow:loss = 0.606356, step = 1500 (0.288 sec) INFO:tensorflow:global_step/sec: 351.362 INFO:tensorflow:loss = 0.59807724, step = 1600 (0.285 sec) INFO:tensorflow:global_step/sec: 366.628 INFO:tensorflow:loss = 0.5832784, step = 1700 (0.273 sec) INFO:tensorflow:global_step/sec: 367.034 INFO:tensorflow:loss = 0.5664347, step = 1800 (0.273 sec) INFO:tensorflow:global_step/sec: 372.339 INFO:tensorflow:loss = 0.5684726, step = 1900 (0.268 sec) INFO:tensorflow:global_step/sec: 368.957 INFO:tensorflow:loss = 0.56011164, step = 2000 (0.271 sec) INFO:tensorflow:global_step/sec: 373.128 INFO:tensorflow:loss = 0.5483226, step = 2100 (0.268 sec) INFO:tensorflow:global_step/sec: 377.334 INFO:tensorflow:loss = 0.5447233, step = 2200 (0.265 sec) INFO:tensorflow:global_step/sec: 370.421 INFO:tensorflow:loss = 0.5358016, step = 2300 (0.270 sec) INFO:tensorflow:global_step/sec: 367.076 INFO:tensorflow:loss = 0.53145075, step = 2400 (0.273 sec) INFO:tensorflow:global_step/sec: 373.596 INFO:tensorflow:loss = 0.50931674, step = 2500 (0.268 sec) INFO:tensorflow:global_step/sec: 368.939 INFO:tensorflow:loss = 0.5253717, step = 2600 (0.271 sec) INFO:tensorflow:global_step/sec: 354.814 INFO:tensorflow:loss = 0.52558273, step = 2700 (0.282 sec) INFO:tensorflow:global_step/sec: 372.243 INFO:tensorflow:loss = 0.51422054, step = 2800 (0.269 sec) INFO:tensorflow:global_step/sec: 366.891 INFO:tensorflow:loss = 0.49747026, step = 2900 (0.272 sec) INFO:tensorflow:global_step/sec: 370.952 INFO:tensorflow:loss = 0.49974674, step = 3000 (0.270 sec) INFO:tensorflow:global_step/sec: 364.158 INFO:tensorflow:loss = 0.4978399, step = 3100 (0.275 sec) INFO:tensorflow:global_step/sec: 365.383 INFO:tensorflow:loss = 0.5030147, step = 3200 (0.273 sec) INFO:tensorflow:global_step/sec: 366.791 INFO:tensorflow:loss = 0.4772169, step = 3300 (0.273 sec) INFO:tensorflow:global_step/sec: 372.438 INFO:tensorflow:loss = 0.46993533, step = 3400 (0.269 sec) INFO:tensorflow:global_step/sec: 371.25 INFO:tensorflow:loss = 0.47242266, step = 3500 (0.269 sec) INFO:tensorflow:global_step/sec: 369.725 INFO:tensorflow:loss = 0.46513358, step = 3600 (0.271 sec) INFO:tensorflow:global_step/sec: 371.002 INFO:tensorflow:loss = 0.4762191, step = 3700 (0.270 sec) INFO:tensorflow:global_step/sec: 369.304 INFO:tensorflow:loss = 0.44923267, step = 3800 (0.271 sec) INFO:tensorflow:global_step/sec: 369.344 INFO:tensorflow:loss = 0.45467538, step = 3900 (0.271 sec) INFO:tensorflow:global_step/sec: 375.58 INFO:tensorflow:loss = 0.46056622, step = 4000 (0.266 sec) INFO:tensorflow:global_step/sec: 347.461 INFO:tensorflow:loss = 0.4489282, step = 4100 (0.288 sec) INFO:tensorflow:global_step/sec: 368.435 INFO:tensorflow:loss = 0.45647347, step = 4200 (0.272 sec) INFO:tensorflow:global_step/sec: 369.159 INFO:tensorflow:loss = 0.4444633, step = 4300 (0.271 sec) INFO:tensorflow:global_step/sec: 371.995 INFO:tensorflow:loss = 0.44425523, step = 4400 (0.269 sec) INFO:tensorflow:global_step/sec: 373.586 INFO:tensorflow:loss = 0.44025964, step = 4500 (0.268 sec) INFO:tensorflow:global_step/sec: 373.136 INFO:tensorflow:loss = 0.44341013, step = 4600 (0.269 sec) INFO:tensorflow:global_step/sec: 369.751 INFO:tensorflow:loss = 0.42856425, step = 4700 (0.269 sec) INFO:tensorflow:global_step/sec: 364.219 INFO:tensorflow:loss = 0.44144967, step = 4800 (0.275 sec) INFO:tensorflow:global_step/sec: 372.675 INFO:tensorflow:loss = 0.42951846, step = 4900 (0.268 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 5000... INFO:tensorflow:Saving checkpoints for 5000 into /tmp/tmpxdgumb2t/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 5000... INFO:tensorflow:Loss for final step: 0.42713496. <tensorflow_estimator.python.estimator.canned.dnn.DNNClassifierV2 at 0x7fad05e33910>
โปรดทราบว่าคุณรวมการเรียก input_fn
ของคุณใน lambda
เพื่อดักจับอาร์กิวเมนต์ในขณะที่จัดเตรียมฟังก์ชันอินพุตที่ไม่รับอาร์กิวเมนต์ ตามที่ตัวประมาณการคาดไว้ อาร์กิวเมนต์ steps
จะบอกวิธีการหยุดการฝึกหลังจากฝึกหลายขั้นตอน
ประเมินแบบจำลองที่ได้รับการฝึกอบรม
เมื่อโมเดลได้รับการฝึกอบรมแล้ว คุณสามารถดูสถิติเกี่ยวกับประสิทธิภาพของโมเดลได้ บล็อกโค้ดต่อไปนี้จะประเมินความถูกต้องของแบบจำลองที่ได้รับการฝึกอบรมจากข้อมูลการทดสอบ:
eval_result = classifier.evaluate(
input_fn=lambda: input_fn(test, test_y, training=False))
print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-26T06:41:28 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpxdgumb2t/model.ckpt-5000 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Inference Time : 0.40087s INFO:tensorflow:Finished evaluation at 2022-01-26-06:41:28 INFO:tensorflow:Saving dict for global step 5000: accuracy = 0.8666667, average_loss = 0.49953422, global_step = 5000, loss = 0.49953422 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 5000: /tmp/tmpxdgumb2t/model.ckpt-5000 Test set accuracy: 0.867
ไม่เหมือนกับวิธีการเรียก train
คุณไม่ได้ผ่านอาร์กิวเมนต์ steps
เพื่อประเมิน input_fn
สำหรับ eval ให้ผลข้อมูลเพียง ยุค เดียวเท่านั้น
พจนานุกรม eval_result
ยังมี average_loss
(การสูญเสียเฉลี่ยต่อตัวอย่าง) การ loss
(การสูญเสียเฉลี่ยต่อชุดย่อย) และมูลค่าของ global_step
ของตัวประมาณ (จำนวนการทำซ้ำการฝึกอบรมที่ได้รับ)
การคาดคะเน (อนุมาน) จากตัวแบบฝึก
ตอนนี้คุณมีรูปแบบการฝึกอบรมที่ให้ผลการประเมินที่ดี ตอนนี้คุณสามารถใช้แบบจำลองที่ได้รับการฝึกอบรมมาเพื่อทำนายสายพันธุ์ของดอกไอริสตามการวัดที่ไม่ได้ติดป้ายกำกับ เช่นเดียวกับการฝึกอบรมและการประเมิน คุณทำการคาดคะเนโดยใช้การเรียกใช้ฟังก์ชันเดียว:
# Generate predictions from the model
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
'SepalLength': [5.1, 5.9, 6.9],
'SepalWidth': [3.3, 3.0, 3.1],
'PetalLength': [1.7, 4.2, 5.4],
'PetalWidth': [0.5, 1.5, 2.1],
}
def input_fn(features, batch_size=256):
"""An input function for prediction."""
# Convert the inputs to a Dataset without labels.
return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)
predictions = classifier.predict(
input_fn=lambda: input_fn(predict_x))
วิธี predict
จะส่งคืน Python iterable ให้ผลพจนานุกรมของผลลัพธ์การทำนายสำหรับแต่ละตัวอย่าง รหัสต่อไปนี้พิมพ์คำทำนายและความน่าจะเป็นบางส่วน:
for pred_dict, expec in zip(predictions, expected):
class_id = pred_dict['class_ids'][0]
probability = pred_dict['probabilities'][class_id]
print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
SPECIES[class_id], 100 * probability, expec))
INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmpxdgumb2t/model.ckpt-5000 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. Prediction is "Setosa" (84.4%), expected "Setosa" Prediction is "Versicolor" (49.3%), expected "Versicolor" Prediction is "Virginica" (57.7%), expected "Virginica"