Xem trên TensorFlow.org | Chạy trong Google Colab | Xem nguồn trên GitHub | Tải xuống sổ ghi chép |
Tổng quat
Công cụ ước tính TensorFlow được hỗ trợ trong TensorFlow và có thể được tạo từ các mô hình tf.keras
mới và hiện có. Hướng dẫn này chứa một ví dụ đầy đủ, tối thiểu về quy trình đó.
Thành lập
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
Tạo một mô hình Keras đơn giản.
Trong Keras, bạn lắp ráp các lớp để xây dựng mô hình . Một mô hình (thường) là một đồ thị của các lớp. Loại mô hình phổ biến nhất là một chồng các lớp: mô hình tf.keras.Sequential
.
Để xây dựng một mạng đơn giản, được kết nối đầy đủ (tức là perceptron nhiều lớp):
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(3)
])
Biên dịch mô hình và nhận một bản tóm tắt.
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer='adam')
model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= dense (Dense) (None, 16) 80 dropout (Dropout) (None, 16) 0 dense_1 (Dense) (None, 3) 51 ================================================================= Total params: 131 Trainable params: 131 Non-trainable params: 0 _________________________________________________________________
Tạo một hàm đầu vào
Sử dụng API tập dữ liệu để chia tỷ lệ thành tập dữ liệu lớn hoặc đào tạo nhiều thiết bị.
Người ước tính cần kiểm soát thời gian và cách thức xây dựng đường ống đầu vào của họ. Để cho phép điều này, chúng yêu cầu một "Hàm đầu vào" hoặc input_fn
. Estimator
sẽ gọi hàm này mà không có đối số. input_fn
phải trả về tf.data.Dataset
.
def input_fn():
split = tfds.Split.TRAIN
dataset = tfds.load('iris', split=split, as_supervised=True)
dataset = dataset.map(lambda features, labels: ({'dense_input':features}, labels))
dataset = dataset.batch(32).repeat()
return dataset
Kiểm tra input_fn
của bạn
for features_batch, labels_batch in input_fn().take(1):
print(features_batch)
print(labels_batch)
{'dense_input': <tf.Tensor: shape=(32, 4), dtype=float32, numpy= array([[5.1, 3.4, 1.5, 0.2], [7.7, 3. , 6.1, 2.3], [5.7, 2.8, 4.5, 1.3], [6.8, 3.2, 5.9, 2.3], [5.2, 3.4, 1.4, 0.2], [5.6, 2.9, 3.6, 1.3], [5.5, 2.6, 4.4, 1.2], [5.5, 2.4, 3.7, 1. ], [4.6, 3.4, 1.4, 0.3], [7.7, 2.8, 6.7, 2. ], [7. , 3.2, 4.7, 1.4], [4.6, 3.2, 1.4, 0.2], [6.5, 3. , 5.2, 2. ], [5.5, 4.2, 1.4, 0.2], [5.4, 3.9, 1.3, 0.4], [5. , 3.5, 1.3, 0.3], [5.1, 3.8, 1.5, 0.3], [4.8, 3. , 1.4, 0.1], [6.5, 3. , 5.8, 2.2], [7.6, 3. , 6.6, 2.1], [6.7, 3.3, 5.7, 2.1], [7.9, 3.8, 6.4, 2. ], [6.7, 3. , 5.2, 2.3], [5.8, 4. , 1.2, 0.2], [6.3, 2.5, 5. , 1.9], [5. , 3. , 1.6, 0.2], [6.9, 3.1, 5.1, 2.3], [6.1, 3. , 4.6, 1.4], [5.8, 2.7, 4.1, 1. ], [5.2, 2.7, 3.9, 1.4], [6.7, 3. , 5. , 1.7], [5.7, 2.6, 3.5, 1. ]], dtype=float32)>} tf.Tensor([0 2 1 2 0 1 1 1 0 2 1 0 2 0 0 0 0 0 2 2 2 2 2 0 2 0 2 1 1 1 1 1], shape=(32,), dtype=int64)
Tạo Công cụ ước tính từ mô hình tf.keras.
Một tf.keras.Model
có thể được đào tạo với API tf.estimator
bằng cách chuyển đổi mô hình thành đối tượng tf.estimator.Estimator
với tf.keras.estimator.model_to_estimator
.
import tempfile
model_dir = tempfile.mkdtemp()
keras_estimator = tf.keras.estimator.model_to_estimator(
keras_model=model, model_dir=model_dir)
INFO:tensorflow:Using default config. INFO:tensorflow:Using default config. INFO:tensorflow:Using the Keras model provided. INFO:tensorflow:Using the Keras model provided. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/backend.py:450: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model. warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and ' INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp2jzrjbqb', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp2jzrjbqb', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
Đào tạo và đánh giá người lập dự toán.
keras_estimator.train(input_fn=input_fn, steps=500)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:397: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp2jzrjbqb/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmp2jzrjbqb/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={}) INFO:tensorflow:Warm-starting from: /tmp/tmp2jzrjbqb/keras/keras_model.ckpt INFO:tensorflow:Warm-starting from: /tmp/tmp2jzrjbqb/keras/keras_model.ckpt INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES. INFO:tensorflow:Warm-started 4 variables. INFO:tensorflow:Warm-started 4 variables. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Create CheckpointSaverHook. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp2jzrjbqb/model.ckpt. INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp2jzrjbqb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... INFO:tensorflow:loss = 3.2731433, step = 0 INFO:tensorflow:loss = 3.2731433, step = 0 INFO:tensorflow:global_step/sec: 19.6463 INFO:tensorflow:global_step/sec: 19.6463 INFO:tensorflow:loss = 1.012466, step = 100 (5.092 sec) INFO:tensorflow:loss = 1.012466, step = 100 (5.092 sec) INFO:tensorflow:global_step/sec: 19.705 INFO:tensorflow:global_step/sec: 19.705 INFO:tensorflow:loss = 0.9225232, step = 200 (5.075 sec) INFO:tensorflow:loss = 0.9225232, step = 200 (5.075 sec) INFO:tensorflow:global_step/sec: 19.9236 INFO:tensorflow:global_step/sec: 19.9236 INFO:tensorflow:loss = 0.8686823, step = 300 (5.019 sec) INFO:tensorflow:loss = 0.8686823, step = 300 (5.019 sec) INFO:tensorflow:global_step/sec: 19.8862 INFO:tensorflow:global_step/sec: 19.8862 INFO:tensorflow:loss = 0.6412657, step = 400 (5.029 sec) INFO:tensorflow:loss = 0.6412657, step = 400 (5.029 sec) INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500... INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 500... INFO:tensorflow:Saving checkpoints for 500 into /tmp/tmp2jzrjbqb/model.ckpt. INFO:tensorflow:Saving checkpoints for 500 into /tmp/tmp2jzrjbqb/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500... INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 500... INFO:tensorflow:Loss for final step: 0.65391386. INFO:tensorflow:Loss for final step: 0.65391386. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Calling model_fn. INFO:tensorflow:Done calling model_fn. /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/training_v1.py:2057: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically. updates = self.state_updates INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Starting evaluation at 2022-01-26T06:39:31 INFO:tensorflow:Starting evaluation at 2022-01-26T06:39:31 INFO:tensorflow:Graph was finalized. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmp/tmp2jzrjbqb/model.ckpt-500 INFO:tensorflow:Restoring parameters from /tmp/tmp2jzrjbqb/model.ckpt-500 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Done running local_init_op. INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [1/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [2/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [3/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [4/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [5/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [6/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [7/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [8/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [9/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Evaluation [10/10] INFO:tensorflow:Inference Time : 0.63967s INFO:tensorflow:Inference Time : 0.63967s INFO:tensorflow:Finished evaluation at 2022-01-26-06:39:31 INFO:tensorflow:Finished evaluation at 2022-01-26-06:39:31 INFO:tensorflow:Saving dict for global step 500: global_step = 500, loss = 0.6503415 INFO:tensorflow:Saving dict for global step 500: global_step = 500, loss = 0.6503415 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmp/tmp2jzrjbqb/model.ckpt-500 INFO:tensorflow:Saving 'checkpoint_path' summary for global step 500: /tmp/tmp2jzrjbqb/model.ckpt-500 Eval result: {'loss': 0.6503415, 'global_step': 500}