זכויות יוצרים 2021 מחברי TF-Agents.
הצג באתר TensorFlow.org | הפעל בגוגל קולאב | צפה במקור ב-GitHub | הורד מחברת |
מבוא
tf_agents.utils.common.Checkpointer
הוא כלי כדי להציל / לטעון שהמדינה הכשרה, המדינה מדיניות, ומדינת replay_buffer מ / אל אחסון מקומי.
tf_agents.policies.policy_saver.PolicySaver
הוא כלי כדי לשמור / לטעון רק את המדיניות, והוא קל יותר מאשר Checkpointer
. אתה יכול להשתמש PolicySaver
לפרוס את המודל גם ללא כל ידע של קוד שיצרו את המדיניות.
במדריך זה, נשתמש DQN לאמן מודל, ואז להשתמש Checkpointer
ו PolicySaver
כדי להראות איך אנחנו יכולים לאחסן ולטעון המדינות ודגם באופן אינטראקטיבי. הערה כי נשתמש נוסע saved_model החדש של TF2.0 ומתכונת PolicySaver
.
להכין
אם לא התקנת את התלויות הבאות, הרץ:
sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython
try:
from google.colab import files
except ImportError:
files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()
סוכן DQN
אנחנו הולכים להקים סוכן DQN, בדיוק כמו בקולאב הקודם. הפרטים מוסתרים כברירת מחדל מכיוון שהם אינם חלק הליבה של קולאב זה, אבל אתה יכול ללחוץ על 'הצג קוד' כדי לראות את הפרטים.
היפרפרמטרים
env_name = "CartPole-v1"
collect_steps_per_iteration = 100
replay_buffer_capacity = 100000
fc_layer_params = (100,)
batch_size = 64
learning_rate = 1e-3
log_interval = 5
num_eval_episodes = 10
eval_interval = 1000
סביבה
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
סוֹכֵן
q_net = q_network.QNetwork(
train_env.observation_spec(),
train_env.action_spec(),
fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
global_step = tf.compat.v1.train.get_or_create_global_step()
agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=global_step)
agent.initialize()
איסוף נתונים
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
data_spec=agent.collect_data_spec,
batch_size=train_env.batch_size,
max_length=replay_buffer_capacity)
collect_driver = dynamic_step_driver.DynamicStepDriver(
train_env,
agent.collect_policy,
observers=[replay_buffer.add_batch],
num_steps=collect_steps_per_iteration)
# Initial data collection
collect_driver.run()
# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
num_parallel_calls=3, sample_batch_size=batch_size,
num_steps=2).prefetch(3)
iterator = iter(dataset)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version. Instructions for updating: Use `as_dataset(..., single_deterministic_pass=False) instead.
תאמן את הסוכן
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)
def train_one_iteration():
# Collect a few steps using collect_policy and save to the replay buffer.
collect_driver.run()
# Sample a batch of data from the buffer and update the agent's network.
experience, unused_info = next(iterator)
train_loss = agent.train(experience)
iteration = agent.train_step_counter.numpy()
print ('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))
יצירת וידאו
def embed_gif(gif_buffer):
"""Embeds a gif file in the notebook."""
tag = '<img src="data:image/gif;base64,{0}"/>'.format(base64.b64encode(gif_buffer).decode())
return IPython.display.HTML(tag)
def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):
num_episodes = 3
frames = []
for _ in range(num_episodes):
time_step = eval_tf_env.reset()
frames.append(eval_py_env.render())
while not time_step.is_last():
action_step = policy.action(time_step)
time_step = eval_tf_env.step(action_step.action)
frames.append(eval_py_env.render())
gif_file = io.BytesIO()
imageio.mimsave(gif_file, frames, format='gif', fps=60)
IPython.display.display(embed_gif(gif_file.getvalue()))
צור סרטון
בדוק את ביצועי המדיניות על ידי יצירת סרטון.
print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step: <tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>
הגדרת Checkpointer ו-PoliceSaver
כעת אנו מוכנים להשתמש ב-Checkpointer וב-PoliceSaver.
מחסום
checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
ckpt_dir=checkpoint_dir,
max_to_keep=1,
agent=agent,
policy=agent.policy,
replay_buffer=replay_buffer,
global_step=global_step
)
חוסך מדיניות
policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
אימון איטרציה אחת
print('Training one iteration....')
train_one_iteration()
Training one iteration.... WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version. Instructions for updating: back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.foldr(fn, elems, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems)) WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version. Instructions for updating: back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.foldr(fn, elems, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems)) iteration: 1 loss: 1.0214563608169556
שמור למחסום
train_checkpointer.save(global_step)
שחזר את המחסום
כדי שזה יעבוד, יש ליצור מחדש את כל מערך האובייקטים באותו אופן כמו בעת יצירת המחסום.
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
שמור גם מדיניות וייצא למיקום
tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel. WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/policy/assets /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) INFO:tensorflow:Assets written to: /tmp/policy/assets
ניתן לטעון את המדיניות מבלי לדעת באיזה סוכן או רשת נעשה שימוש כדי ליצור אותה. זה הופך את הפריסה של המדיניות להרבה יותר קלה.
טען את המדיניות השמורה ובדוק כיצד היא פועלת
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
יצוא ויבוא
שאר הקולאב יעזור לך לייצא / לייבא ספריות ביקורת ומדיניות כך שתוכל להמשיך באימונים בשלב מאוחר יותר ולפרוס את המודל מבלי שתצטרך להתאמן שוב.
עכשיו אתה יכול לחזור ל'אימון איטרציה אחת' ולהתאמן עוד כמה פעמים כך שתוכל להבין את ההבדל בהמשך. ברגע שתתחיל לראות תוצאות קצת יותר טובות, המשך למטה.
צור קובץ zip והעלה קובץ zip (לחץ פעמיים כדי לראות את הקוד)
def create_zip_file(dirname, base_filename):
return shutil.make_archive(base_filename, 'zip', dirname)
def upload_and_unzip_file_to(dirname):
if files is None:
return
uploaded = files.upload()
for fn in uploaded.keys():
print('User uploaded file "{name}" with length {length} bytes'.format(
name=fn, length=len(uploaded[fn])))
shutil.rmtree(dirname)
zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')
zip_files.extractall(dirname)
zip_files.close()
צור קובץ דחוס מספריית המחסום.
train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))
הורד את קובץ ה-zip.
if files is not None:
files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
לאחר אימון במשך זמן מה (10-15 פעמים), הורד את קובץ ה-zip של המחסום, ועבור אל "זמן ריצה > הפעל מחדש והפעל הכל" כדי לאפס את האימון, וחזור לתא זה. כעת תוכל להעלות את קובץ ה-zip שהורדת, ולהמשיך בהדרכה.
upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
לאחר שהעלית את ספריית המחסומים, חזור אל 'אימון איטרציה אחת' כדי להמשיך באימון או חזור אל 'צור סרטון' כדי לבדוק את ביצועי המדיניות הנטענת.
לחילופין, ניתן לשמור את המדיניות (הדגם) ולשחזר אותה. בניגוד למחסום, אתה לא יכול להמשיך באימון, אבל אתה עדיין יכול לפרוס את המודל. שימו לב שהקובץ שהורד קטן בהרבה מזה של המחסום.
tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel. WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/policy/assets /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
העלה את ספריית המדיניות שהורדת (exported_policy.zip) ובדוק כיצד פועלת המדיניות השמורה.
upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
SavedModelPyTFEagerPolicy
אם אתה לא רוצה להשתמש במדיניות TF, אז אתה יכול גם להשתמש saved_model ישירות עם env Python באמצעות py_tf_eager_policy.SavedModelPyTFEagerPolicy
.
שים לב שזה עובד רק כאשר מצב להוט מופעל.
eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())
# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)
המרת מדיניות ל- TFLite
ראה ממיר לייט TensorFlow לפרטים נוספים.
converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. 2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges. WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
הפעל מסקנות על מודל TFLite
ראה הסקת לייט TensorFlow לפרטים נוספים.
import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))
policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
'0/discount':tf.constant(0.0),
'0/observation':tf.zeros([1,4]),
'0/reward':tf.constant(0.0),
'0/step_type':tf.constant(0)})
{'action': array([0])}