برآوردگرها

مشاهده در TensorFlow.org در Google Colab اجرا شود مشاهده منبع در GitHub دانلود دفترچه یادداشت

این سند tf.estimator - یک API سطح بالا tf.estimator را معرفی می کند. برآوردگرها اقدامات زیر را در بر می گیرند:

  • آموزش
  • ارزیابی
  • پیش بینی
  • صادرات برای خدمت

TensorFlow چندین برآوردگر از پیش ساخته را پیاده سازی می کند. برآوردگرهای سفارشی هنوز پشتیبانی می شوند، اما عمدتاً به عنوان معیار سازگاری با عقب. برآوردگرهای سفارشی نباید برای کد جدید استفاده شوند. همه برآوردگرها - از پیش ساخته شده یا سفارشی - کلاس هایی بر اساس کلاس tf.estimator.Estimator هستند.

برای مثال سریع، آموزش‌های Estimator را امتحان کنید. برای یک نمای کلی از طراحی API، کاغذ سفید را بررسی کنید.

برپایی

pip install -U tensorflow_datasets
import tempfile
import os

import tensorflow as tf
import tensorflow_datasets as tfds

مزایای

مشابه tf.keras.Model ، estimator یک انتزاع در سطح مدل است. tf.estimator برخی از قابلیت‌هایی را فراهم می‌کند که در حال حاضر هنوز در حال توسعه برای tf.keras هستند. اینها هستند:

  • آموزش مبتنی بر سرور پارامتر
  • ادغام کامل TFX

قابلیت های برآوردگرها

برآوردگرها مزایای زیر را ارائه می دهند:

  • شما می توانید مدل های مبتنی بر برآوردگر را بر روی یک میزبان محلی یا در یک محیط چند سرور توزیع شده بدون تغییر مدل خود اجرا کنید. علاوه بر این، می‌توانید مدل‌های مبتنی بر برآوردگر را بر روی CPU، GPU یا TPU بدون کدگذاری مجدد مدل خود اجرا کنید.
  • تخمین‌گرها یک حلقه آموزشی توزیع‌شده ایمن را ارائه می‌کنند که نحوه و زمان انجام موارد زیر را کنترل می‌کند:
    • بارگذاری داده ها
    • رسیدگی به استثناها
    • فایل های ایست بازرسی ایجاد کنید و از خرابی ها بازیابی کنید
    • خلاصه ها را برای TensorBoard ذخیره کنید

هنگام نوشتن برنامه با برآوردگرها، باید خط لوله ورودی داده را از مدل جدا کنید. این جداسازی آزمایش ها را با مجموعه داده های مختلف ساده می کند.

استفاده از برآوردگرهای از پیش ساخته شده

برآوردگرهای از پیش ساخته شده شما را قادر می‌سازد تا در سطح مفهومی بسیار بالاتری نسبت به APIهای پایه TensorFlow کار کنید. دیگر لازم نیست نگران ایجاد نمودار محاسباتی یا جلسات باشید، زیرا برآوردگرها تمام "لوله کشی" را برای شما انجام می دهند. علاوه بر این، برآوردگرهای از پیش ساخته شده به شما اجازه می‌دهند تا با ایجاد حداقل تغییرات کد، معماری‌های مدل‌های مختلف را آزمایش کنید. به عنوان مثال، tf.estimator.DNNClassifier یک کلاس Estimator از پیش ساخته است که مدل‌های طبقه‌بندی را بر اساس شبکه‌های عصبی متراکم و پیش‌خور آموزش می‌دهد.

یک برنامه TensorFlow با تکیه بر یک برآوردگر از پیش ساخته شده معمولاً شامل چهار مرحله زیر است:

1. یک توابع ورودی بنویسید

برای مثال، ممکن است یک تابع برای وارد کردن مجموعه آموزشی و یک تابع دیگر برای وارد کردن مجموعه آزمایشی ایجاد کنید. تخمین‌گران انتظار دارند ورودی‌های آنها به صورت یک جفت شی فرم‌بندی شوند:

  • فرهنگ لغتی که در آن کلیدها نام ویژگی‌ها و مقادیر Tensor (یا SparseTensors) هستند که حاوی داده‌های ویژگی مربوطه است.
  • یک تانسور حاوی یک یا چند برچسب

input_fn باید یک tf.data.Dataset که جفت‌هایی را در آن قالب به دست می‌دهد.

برای مثال، کد زیر یک tf.data.Dataset را از فایل train.csv مجموعه داده تایتانیک می‌سازد:

def train_input_fn():
  titanic_file = tf.keras.utils.get_file("train.csv", "https://storage.googleapis.com/tf-datasets/titanic/train.csv")
  titanic = tf.data.experimental.make_csv_dataset(
      titanic_file, batch_size=32,
      label_name="survived")
  titanic_batches = (
      titanic.cache().repeat().shuffle(500)
      .prefetch(tf.data.AUTOTUNE))
  return titanic_batches

input_fn در یک tf.Graph اجرا می‌شود و همچنین می‌تواند مستقیماً یک جفت (features_dics, labels) حاوی تانسورهای گراف را برگرداند، اما خارج از موارد ساده مانند بازگشت ثابت‌ها مستعد خطا است.

2. ستون های ویژگی را تعریف کنید.

هر tf.feature_column یک نام ویژگی، نوع آن و هرگونه پیش پردازش ورودی را مشخص می کند.

به عنوان مثال، قطعه زیر سه ستون ویژگی ایجاد می کند.

  • اولین مورد از ویژگی age به طور مستقیم به عنوان ورودی ممیز شناور استفاده می کند.
  • دومی از ویژگی class به عنوان ورودی دسته بندی استفاده می کند.
  • سومی از embark_town به عنوان ورودی طبقه‌بندی استفاده می‌کند، اما از hashing trick برای اجتناب از نیاز به شمارش گزینه‌ها و تنظیم تعداد گزینه‌ها استفاده می‌کند.

برای اطلاعات بیشتر، آموزش ستون های ویژگی را بررسی کنید.

age = tf.feature_column.numeric_column('age')
cls = tf.feature_column.categorical_column_with_vocabulary_list('class', ['First', 'Second', 'Third']) 
embark = tf.feature_column.categorical_column_with_hash_bucket('embark_town', 32)

3. برآوردگر از پیش ساخته شده مربوطه را نمونه سازی کنید.

به عنوان مثال، در اینجا نمونه‌ای از یک برآوردگر از پیش ساخته شده به نام LinearClassifier :

model_dir = tempfile.mkdtemp()
model = tf.estimator.LinearClassifier(
    model_dir=model_dir,
    feature_columns=[embark, cls, age],
    n_classes=2
)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpl24pp3cp', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

برای اطلاعات بیشتر، می‌توانید به آموزش طبقه‌بندی خطی بروید.

4. یک روش آموزشی، ارزیابی یا استنتاج فراخوانی کنید.

همه برآوردگرها روش های train ، evaluate و predict را ارائه می دهند.

model = model.train(input_fn=train_input_fn, steps=100)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/training/training_util.py:236: Variable.initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.
Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/base_layer_v1.py:1684: UserWarning: `layer.add_variable` is deprecated and will be removed in a future version. Please use `layer.add_weight` method instead.
  warnings.warn('`layer.add_variable` is deprecated and '
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/optimizer_v2/ftrl.py:147: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpl24pp3cp/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 100...
INFO:tensorflow:Saving checkpoints for 100 into /tmp/tmpl24pp3cp/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 100...
INFO:tensorflow:Loss for final step: 0.6319582.
2021-09-22 20:49:10.453286: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
result = model.evaluate(train_input_fn, steps=10)

for key, value in result.items():
  print(key, ":", value)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:11
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpl24pp3cp/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.74609s
INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:12
INFO:tensorflow:Saving dict for global step 100: accuracy = 0.734375, accuracy_baseline = 0.640625, auc = 0.7373913, auc_precision_recall = 0.64306235, average_loss = 0.563341, global_step = 100, label/mean = 0.359375, loss = 0.563341, precision = 0.734375, prediction/mean = 0.3463129, recall = 0.40869564
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 100: /tmp/tmpl24pp3cp/model.ckpt-100
accuracy : 0.734375
accuracy_baseline : 0.640625
auc : 0.7373913
auc_precision_recall : 0.64306235
average_loss : 0.563341
label/mean : 0.359375
loss : 0.563341
precision : 0.734375
prediction/mean : 0.3463129
recall : 0.40869564
global_step : 100
2021-09-22 20:49:12.168629: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
for pred in model.predict(train_input_fn):
  for key, value in pred.items():
    print(key, ":", value)
  break
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpl24pp3cp/model.ckpt-100
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
logits : [-1.5173098]
logistic : [0.17985801]
probabilities : [0.820142   0.17985801]
class_ids : [0]
classes : [b'0']
all_class_ids : [0 1]
all_classes : [b'0' b'1']
2021-09-22 20:49:13.076528: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

مزایای برآوردگرهای از پیش ساخته شده

برآوردگرهای از پیش ساخته بهترین شیوه ها را رمزگذاری می کنند و مزایای زیر را ارائه می دهند:

  • بهترین روش‌ها برای تعیین محل اجرای بخش‌های مختلف نمودار محاسباتی، اجرای استراتژی‌ها بر روی یک ماشین یا یک خوشه.
  • بهترین شیوه ها برای نوشتن رویداد (خلاصه) و خلاصه های مفید جهانی.

اگر از برآوردگرهای از پیش ساخته شده استفاده نمی کنید، باید ویژگی های قبلی را خودتان پیاده سازی کنید.

برآوردگرهای سفارشی

قلب هر برآوردگر - چه از پیش ساخته شده باشد و چه سفارشی - تابع مدل آن است، model_fn ، که روشی است که نمودارهایی را برای آموزش، ارزیابی و پیش‌بینی می‌سازد. هنگامی که از یک برآوردگر از پیش ساخته شده استفاده می کنید، شخص دیگری قبلاً تابع مدل را پیاده سازی کرده است. هنگام تکیه بر یک برآوردگر سفارشی، باید تابع مدل را خودتان بنویسید.

یک برآوردگر از مدل Keras ایجاد کنید

می‌توانید مدل‌های Keras موجود را با tf.keras.estimator.model_to_estimator به برآوردگر تبدیل کنید. اگر می‌خواهید کد مدل خود را مدرن کنید، این کار مفید است، اما خط لوله آموزشی شما همچنان به برآوردگر نیاز دارد.

یک مدل Keras MobileNet V2 را نمونه‌سازی کنید و مدل را با بهینه‌ساز، ضرر و معیارها برای آموزش با آن کامپایل کنید:

keras_mobilenet_v2 = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False)
keras_mobilenet_v2.trainable = False

estimator_model = tf.keras.Sequential([
    keras_mobilenet_v2,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(1)
])

# Compile the model
estimator_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=['accuracy'])
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step
9420800/9406464 [==============================] - 0s 0us/step

یک Estimator از مدل کامپایل شده Keras ایجاد کنید. حالت اولیه مدل کراس در Estimator ایجاد شده حفظ می شود:

est_mobilenet_v2 = tf.keras.estimator.model_to_estimator(keras_model=estimator_model)
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpmosnmied
INFO:tensorflow:Using the Keras model provided.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/backend.py:401: UserWarning: `tf.keras.backend.set_learning_phase` is deprecated and will be removed after 2020-10-11. To update it, simply pass a True/False value to the `training` argument of the `__call__` method of your layer or model.
  warnings.warn('`tf.keras.backend.set_learning_phase` is deprecated and '
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  category=CustomMaskWarning)
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpmosnmied', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}

با Estimator مشتق شده مانند هر Estimator دیگری رفتار کنید.

IMG_SIZE = 160  # All images will be resized to 160x160

def preprocess(image, label):
  image = tf.cast(image, tf.float32)
  image = (image/127.5) - 1
  image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
  return image, label
def train_input_fn(batch_size):
  data = tfds.load('cats_vs_dogs', as_supervised=True)
  train_data = data['train']
  train_data = train_data.map(preprocess).shuffle(500).batch(batch_size)
  return train_data

برای آموزش، تابع قطار برآوردگر را فراخوانی کنید:

est_mobilenet_v2.train(input_fn=lambda: train_input_fn(32), steps=50)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpmosnmied/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tmpmosnmied/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: /tmp/tmpmosnmied/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting from: /tmp/tmpmosnmied/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Warm-started 158 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpmosnmied/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpmosnmied/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6994096, step = 0
INFO:tensorflow:loss = 0.6994096, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpmosnmied/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmpmosnmied/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.68789804.
INFO:tensorflow:Loss for final step: 0.68789804.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f4b1c1e9890>

به طور مشابه، برای ارزیابی، تابع ارزیابی برآوردگر را فراخوانی کنید:

est_mobilenet_v2.evaluate(input_fn=lambda: train_input_fn(32), steps=10)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/training.py:2470: UserWarning: `Model.state_updates` will be removed in a future version. This property should not be used in TensorFlow 2.0, as `updates` are applied automatically.
  warnings.warn('`Model.state_updates` will be removed in a future version. '
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:36
INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:36
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpmosnmied/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmpmosnmied/model.ckpt-50
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 3.89658s
INFO:tensorflow:Inference Time : 3.89658s
INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:39
INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:39
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.525, global_step = 50, loss = 0.6723582
INFO:tensorflow:Saving dict for global step 50: accuracy = 0.525, global_step = 50, loss = 0.6723582
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpmosnmied/model.ckpt-50
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 50: /tmp/tmpmosnmied/model.ckpt-50
{'accuracy': 0.525, 'loss': 0.6723582, 'global_step': 50}

برای جزئیات بیشتر، لطفاً به مستندات مربوط به tf.keras.estimator.model_to_estimator کنید.

ذخیره ایست های بازرسی مبتنی بر شی با Estimator

تخمین‌گرها به‌طور پیش‌فرض، نقاط بازرسی را با نام متغیرها ذخیره می‌کنند تا نمودار شی که در راهنمای Checkpoint توضیح داده شده است. tf.train.Checkpoint نقاط بازرسی مبتنی بر نام را می‌خواند، اما نام متغیرها ممکن است هنگام جابجایی بخش‌هایی از مدل به خارج از model_fn برآوردگر تغییر کند. برای سازگاری فوروارد، صرفه جویی در نقاط بازرسی مبتنی بر شی، آموزش یک مدل را در داخل تخمینگر و سپس استفاده از آن را در خارج از یکی آسان تر می کند.

import tensorflow.compat.v1 as tf_compat
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
def model_fn(features, labels, mode):
  net = Net()
  opt = tf.keras.optimizers.Adam(0.1)
  ckpt = tf.train.Checkpoint(step=tf_compat.train.get_global_step(),
                             optimizer=opt, net=net)
  with tf.GradientTape() as tape:
    output = net(features['x'])
    loss = tf.reduce_mean(tf.abs(output - features['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  return tf.estimator.EstimatorSpec(
    mode,
    loss=loss,
    train_op=tf.group(opt.apply_gradients(zip(gradients, variables)),
                      ckpt.step.assign_add(1)),
    # Tell the Estimator to save "ckpt" in an object-based format.
    scaffold=tf_compat.train.Scaffold(saver=ckpt))

tf.keras.backend.clear_session()
est = tf.estimator.Estimator(model_fn, './tf_estimator_example/')
est.train(toy_dataset, steps=10)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': './tf_estimator_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 4.659403, step = 0
INFO:tensorflow:loss = 4.659403, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into ./tf_estimator_example/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 39.58891.
INFO:tensorflow:Loss for final step: 39.58891.
<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f4b7c451fd0>

سپس tf.train.Checkpoint می تواند نقاط بازرسی برآوردگر را از model_dir خود بارگیری کند.

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(
  step=tf.Variable(1, dtype=tf.int64), optimizer=opt, net=net)
ckpt.restore(tf.train.latest_checkpoint('./tf_estimator_example/'))
ckpt.step.numpy()  # From est.train(..., steps=10)
10

SavedModels from Estimators

برآوردگرها SavedModels را از طریق tf.Estimator.export_saved_model صادر می کنند.

input_column = tf.feature_column.numeric_column("x")

estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])

def input_fn():
  return tf.data.Dataset.from_tensor_slices(
    ({"x": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)
estimator.train(input_fn)
INFO:tensorflow:Using default config.
INFO:tensorflow:Using default config.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp30_d7xz6
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp30_d7xz6
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30_d7xz6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp30_d7xz6', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp30_d7xz6/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp30_d7xz6/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:loss = 0.6931472, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 50...
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp30_d7xz6/model.ckpt.
INFO:tensorflow:Saving checkpoints for 50 into /tmp/tmp30_d7xz6/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 50...
INFO:tensorflow:Loss for final step: 0.4022895.
INFO:tensorflow:Loss for final step: 0.4022895.
<tensorflow_estimator.python.estimator.canned.linear.LinearClassifierV2 at 0x7f4b1c10fd10>

برای ذخیره یک Estimator باید یک serving_input_receiver ایجاد کنید. این تابع بخشی از یک tf.Graph را می سازد که داده های خام دریافت شده توسط SavedModel را تجزیه می کند.

ماژول tf.estimator.export شامل توابعی است که به ساخت این receivers ها کمک می کند.

کد زیر یک گیرنده بر اساس feature_columns می‌سازد که بافرهای پروتکل feature_columns را می‌پذیرد که اغلب با tf- tf.Example استفاده می‌شوند.

tmpdir = tempfile.mkdtemp()

serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(
  tf.feature_column.make_parse_example_spec([input_column]))

estimator_base_path = os.path.join(tmpdir, 'from_estimator')
estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/signature_def_utils_impl.py:145: build_tensor_info (from tensorflow.python.saved_model.utils_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Classify: ['serving_default', 'classification']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Regress: ['regression']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Predict: ['predict']
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Train: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Signatures INCLUDED in export for Eval: None
INFO:tensorflow:Restoring parameters from /tmp/tmp30_d7xz6/model.ckpt-50
INFO:tensorflow:Restoring parameters from /tmp/tmp30_d7xz6/model.ckpt-50
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:Assets added to graph.
INFO:tensorflow:No assets to write.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: /tmp/tmpi_szzuj1/from_estimator/temp-1632343781/saved_model.pb
INFO:tensorflow:SavedModel written to: /tmp/tmpi_szzuj1/from_estimator/temp-1632343781/saved_model.pb

همچنین می توانید آن مدل را از پایتون بارگیری و اجرا کنید:

imported = tf.saved_model.load(estimator_path)

def predict(x):
  example = tf.train.Example()
  example.features.feature["x"].float_list.value.extend([x])
  return imported.signatures["predict"](
    examples=tf.constant([example.SerializeToString()]))
print(predict(1.5))
print(predict(3.5))
{'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[1]])>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'1']], dtype=object)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.2974025]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.5738074]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.42619258, 0.5738074 ]], dtype=float32)>}
{'class_ids': <tf.Tensor: shape=(1, 1), dtype=int64, numpy=array([[0]])>, 'classes': <tf.Tensor: shape=(1, 1), dtype=string, numpy=array([[b'0']], dtype=object)>, 'all_classes': <tf.Tensor: shape=(1, 2), dtype=string, numpy=array([[b'0', b'1']], dtype=object)>, 'all_class_ids': <tf.Tensor: shape=(1, 2), dtype=int32, numpy=array([[0, 1]], dtype=int32)>, 'logits': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[-1.1919093]], dtype=float32)>, 'logistic': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[0.23291764]], dtype=float32)>, 'probabilities': <tf.Tensor: shape=(1, 2), dtype=float32, numpy=array([[0.7670824 , 0.23291762]], dtype=float32)>}

tf.estimator.export.build_raw_serving_input_receiver_fn به شما امکان می دهد توابع ورودی ایجاد کنید که تانسورهای خام را به جای tf.train.Example s می گیرند.

استفاده از tf.distribute.Strategy با برآوردگر (پشتیبانی محدود)

tf.estimator یک API آموزشی توزیع شده TensorFlow است که در ابتدا از رویکرد سرور پارامتر async پشتیبانی می کرد. tf.estimator اکنون tf.distribute.Strategy پشتیبانی می کند. اگر از tf.estimator استفاده می کنید، می توانید با تغییرات بسیار کمی در کد خود به آموزش توزیع شده تغییر دهید. با این کار، کاربران Estimator اکنون می توانند آموزش های توزیع شده همزمان را روی چندین GPU و چندین کارگر انجام دهند و همچنین از TPU ها استفاده کنند. با این حال، این پشتیبانی در Estimator محدود است. برای جزئیات بیشتر، بخش What's support now در زیر را بررسی کنید.

استفاده از tf.distribute.Strategy با برآوردگر کمی متفاوت از مورد Keras است. به جای استفاده از strategi.scope، اکنون شی strategy.scope را به RunConfig برای برآوردگر منتقل می‌کنید.

برای اطلاعات بیشتر می توانید به راهنمای آموزشی توزیع شده مراجعه کنید.

در اینجا یک قطعه کد است که این را با یک تخمینگر خطی از پیش ساخته شده و MirroredStrategy LinearRegressor نشان می دهد:

mirrored_strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(
    train_distribute=mirrored_strategy, eval_distribute=mirrored_strategy)
regressor = tf.estimator.LinearRegressor(
    feature_columns=[tf.feature_column.numeric_column('feats')],
    optimizer='SGD',
    config=config)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Initializing RunConfig with distribution strategies.
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Not using Distribute Coordinator.
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpftw63jyd
WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpftw63jyd
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpftw63jyd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}
INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmpftw63jyd', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_device_fn': None, '_protocol': None, '_eval_distribute': <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7f4b0c04c050>, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_distribute_coordinator_mode': None}

در اینجا، شما از یک برآوردگر از پیش ساخته شده استفاده می کنید، اما همان کد با یک تخمینگر سفارشی نیز کار می کند. train_distribute نحوه توزیع آموزش را تعیین می کند و eval_distribute تعیین می کند که ارزیابی چگونه توزیع شود. این یک تفاوت دیگر با Keras است که در آن از استراتژی یکسانی برای تمرین و ارزیابی استفاده می کنید.

اکنون می توانید این برآوردگر را با یک تابع ورودی آموزش و ارزیابی کنید:

def input_fn():
  dataset = tf.data.Dataset.from_tensors(({"feats":[1.]}, [1.]))
  return dataset.repeat(1000).batch(10)
regressor.train(input_fn=input_fn, steps=10)
regressor.evaluate(input_fn=input_fn, steps=10)
INFO:tensorflow:Calling model_fn.
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/data/ops/dataset_ops.py:374: UserWarning: To make it possible to preserve tf.data options across serialization boundaries, their implementation has moved to be part of the TensorFlow graph. As a consequence, the options value is in general no longer known at graph construction time. Invoking this method in graph mode retains the legacy behavior of the original implementation, but note that the returned value might not reflect the actual value of the options.
  warnings.warn("To make it possible to preserve tf.data options across "
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Create CheckpointSaverHook.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_estimator/python/estimator/util.py:95: DistributedIteratorV1.initialize (from tensorflow.python.distribute.input_lib) is deprecated and will be removed in a future version.
Instructions for updating:
Use the iterator's `initializer` property instead.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpftw63jyd/model.ckpt.
INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpftw63jyd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
2021-09-22 20:49:45.706166: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2021-09-22 20:49:45.707521: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:loss = 1.0, step = 0
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 10...
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpftw63jyd/model.ckpt.
INFO:tensorflow:Saving checkpoints for 10 into /tmp/tmpftw63jyd/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 10...
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Loss for final step: 2.877698e-13.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /replica:0/task:0/device:CPU:0 then broadcast to ('/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:46
INFO:tensorflow:Starting evaluation at 2021-09-22T20:49:46
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/tmpftw63jyd/model.ckpt-10
INFO:tensorflow:Restoring parameters from /tmp/tmpftw63jyd/model.ckpt-10
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Done running local_init_op.
2021-09-22 20:49:46.680821: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorFromStringHandle' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorFromStringHandle} }
    .  Registered:  device='CPU'

2021-09-22 20:49:46.682161: W tensorflow/core/grappler/utils/graph_view.cc:836] No registered 'MultiDeviceIteratorGetNextFromShard' OpKernel for GPU devices compatible with node { {node MultiDeviceIteratorGetNextFromShard} }
    .  Registered:  device='CPU'
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Inference Time : 0.26514s
INFO:tensorflow:Inference Time : 0.26514s
INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:46
INFO:tensorflow:Finished evaluation at 2021-09-22-20:49:46
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving dict for global step 10: average_loss = 1.4210855e-14, global_step = 10, label/mean = 1.0, loss = 1.4210855e-14, prediction/mean = 0.99999994
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpftw63jyd/model.ckpt-10
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 10: /tmp/tmpftw63jyd/model.ckpt-10
{'average_loss': 1.4210855e-14,
 'label/mean': 1.0,
 'loss': 1.4210855e-14,
 'prediction/mean': 0.99999994,
 'global_step': 10}

تفاوت دیگری که باید در اینجا بین Estimator و Keras برجسته شود، مدیریت ورودی است. در Keras، هر دسته از مجموعه داده به طور خودکار در بین چندین نسخه تقسیم می شود. با این حال، در Estimator، شما تقسیم خودکار دسته‌ای را انجام نمی‌دهید، و به‌طور خودکار داده‌ها را بین کارگران مختلف تقسیم نمی‌کنید. شما کنترل کاملی بر نحوه توزیع داده های خود در بین کارگران و دستگاه ها دارید و باید یک input_fn برای تعیین نحوه توزیع داده های خود ارائه دهید.

input_fn شما یک بار به ازای هر کارگر فراخوانی می شود، بنابراین به ازای هر کارگر یک مجموعه داده می دهد. سپس یک دسته از آن مجموعه داده به یک ماکت در آن کارگر داده می شود، در نتیجه N دسته برای N ماکت در 1 کارگر مصرف می شود. به عبارت دیگر، مجموعه داده بازگردانده شده توسط input_fn باید دسته هایی با اندازه PER_REPLICA_BATCH_SIZE ارائه دهد. و اندازه دسته کلی برای یک مرحله را می توان به عنوان PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync به دست آورد.

هنگام انجام آموزش چند کارگری، یا باید داده های خود را بین کارگران تقسیم کنید، یا با یک دانه تصادفی در هر کدام به هم بزنید. نمونه ای از نحوه انجام این کار را می توانید در آموزش Multi-worker with Estimator بررسی کنید.

و به طور مشابه، می توانید از استراتژی های سرور چندکاره و پارامتری نیز استفاده کنید. کد ثابت باقی می ماند، اما باید از tf.estimator.train_and_evaluate استفاده کنید و متغیرهای محیطی TF_CONFIG را برای هر باینری در حال اجرا در خوشه خود تنظیم کنید.

چه چیزی اکنون پشتیبانی می شود؟

پشتیبانی محدودی برای آموزش با Estimator با استفاده از همه استراتژی ها به جز TPUStrategy دارد. آموزش و ارزیابی اولیه باید کار کند، اما تعدادی از ویژگی های پیشرفته مانند v1.train.Scaffold این کار را نمی کند. همچنین ممکن است تعدادی اشکال در این ادغام وجود داشته باشد و هیچ برنامه‌ای برای بهبود فعال این پشتیبانی وجود ندارد (تمرکز بر Keras و پشتیبانی از حلقه آموزشی سفارشی است). در صورت امکان، بهتر است به جای آن از tf.distribute با آن APIها استفاده کنید.

API آموزشی Mirrored Strategy TPUStrategy MultiWorkerMirroredStrategy CentralStorageStrategy ParameterServerStrategy
تخمینگر API پشتیبانی محدود پشتیبانی نشده پشتیبانی محدود پشتیبانی محدود پشتیبانی محدود

نمونه ها و آموزش ها

در اینجا چند مثال سرتاسری وجود دارد که نحوه استفاده از استراتژی‌های مختلف را با برآوردگر نشان می‌دهد:

  1. آموزش Multi-worker Training with Estimator نشان می دهد که چگونه می توانید با چندین کارگر با استفاده از MultiWorkerMirroredStrategy در مجموعه داده MNIST آموزش دهید.
  2. یک مثال سرتاسری از اجرای آموزش چند کارگری با استراتژی‌های توزیع در tensorflow/ecosystem با استفاده از الگوهای Kubernetes. با یک مدل Keras شروع می شود و با استفاده از tf.keras.estimator.model_to_estimator API آن را به یک برآوردگر تبدیل می کند.
  3. مدل رسمی ResNet50 که می تواند با استفاده از MirroredStrategy یا MultiWorkerMirroredStrategy آموزش ببیند.