কপিরাইট 2021 টিএফ-এজেন্ট লেখক।
TensorFlow.org এ দেখুন | Google Colab-এ চালান | GitHub-এ উৎস দেখুন | নোটবুক ডাউনলোড করুন |
ভূমিকা
tf_agents.utils.common.Checkpointer
সংরক্ষণ / একটি স্থানীয় সংগ্রহস্থল থেকে প্রশিক্ষণ রাষ্ট্র, নীতি রাষ্ট্র, এবং / replay_buffer রাষ্ট্র লোড করতে একটি ইউটিলিটি।
tf_agents.policies.policy_saver.PolicySaver
সংরক্ষণ করতে / লোড শুধুমাত্র নীতিটি টুল, এবং চেয়ে অনেক লঘুতর Checkpointer
। আপনি ব্যবহার করতে পারেন PolicySaver
কোডটি নীতি তৈরি কোন অজ্ঞাতসারে পাশাপাশি মডেল স্থাপন করা।
এই টিউটোরিয়াল, আমরা DQN ব্যবহার একটি মডেল প্রশিক্ষণ, তারপর ব্যবহার করা হবে Checkpointer
এবং PolicySaver
প্রদর্শনী আমরা কিভাবে সংরক্ষণ এবং একটি ইন্টারেক্টিভ ভাবে রাজ্য এবং মডেল লোড পারেন। মনে রাখবেন আমরা TF2.0 এর নতুন saved_model সাধনী দ্বারা প্রয়োগকরণ এবং জন্য ফর্ম্যাট ব্যবহার করবে PolicySaver
।
সেটআপ
আপনি যদি নিম্নলিখিত নির্ভরতাগুলি ইনস্টল না করে থাকেন তবে চালান:
sudo apt-get update
sudo apt-get install -y xvfb ffmpeg python-opengl
pip install pyglet
pip install 'imageio==2.4.0'
pip install 'xvfbwrapper==0.2.9'
pip install tf-agents[reverb]
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import base64
import imageio
import io
import matplotlib
import matplotlib.pyplot as plt
import os
import shutil
import tempfile
import tensorflow as tf
import zipfile
import IPython
try:
from google.colab import files
except ImportError:
files = None
from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common
tempdir = os.getenv("TEST_TMPDIR", tempfile.gettempdir())
# Set up a virtual display for rendering OpenAI gym environments.
import xvfbwrapper
xvfbwrapper.Xvfb(1400, 900, 24).start()
DQN এজেন্ট
আমরা আগের কোলাবের মতই DQN এজেন্ট সেট আপ করতে যাচ্ছি। বিশদগুলি ডিফল্টরূপে লুকানো থাকে কারণ সেগুলি এই কোল্যাবের মূল অংশ নয়, তবে আপনি বিস্তারিত দেখতে 'কোড দেখান' এ ক্লিক করতে পারেন৷
হাইপারপ্যারামিটার
env_name = "CartPole-v1"
collect_steps_per_iteration = 100
replay_buffer_capacity = 100000
fc_layer_params = (100,)
batch_size = 64
learning_rate = 1e-3
log_interval = 5
num_eval_episodes = 10
eval_interval = 1000
পরিবেশ
train_py_env = suite_gym.load(env_name)
eval_py_env = suite_gym.load(env_name)
train_env = tf_py_environment.TFPyEnvironment(train_py_env)
eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)
প্রতিনিধি
q_net = q_network.QNetwork(
train_env.observation_spec(),
train_env.action_spec(),
fc_layer_params=fc_layer_params)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
global_step = tf.compat.v1.train.get_or_create_global_step()
agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=global_step)
agent.initialize()
তথ্য সংগ্রহ
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
data_spec=agent.collect_data_spec,
batch_size=train_env.batch_size,
max_length=replay_buffer_capacity)
collect_driver = dynamic_step_driver.DynamicStepDriver(
train_env,
agent.collect_policy,
observers=[replay_buffer.add_batch],
num_steps=collect_steps_per_iteration)
# Initial data collection
collect_driver.run()
# Dataset generates trajectories with shape [BxTx...] where
# T = n_step_update + 1.
dataset = replay_buffer.as_dataset(
num_parallel_calls=3, sample_batch_size=batch_size,
num_steps=2).prefetch(3)
iterator = iter(dataset)
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/autograph/impl/api.py:383: ReplayBuffer.get_next (from tf_agents.replay_buffers.replay_buffer) is deprecated and will be removed in a future version. Instructions for updating: Use `as_dataset(..., single_deterministic_pass=False) instead.
এজেন্টকে প্রশিক্ষণ দিন
# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)
def train_one_iteration():
# Collect a few steps using collect_policy and save to the replay buffer.
collect_driver.run()
# Sample a batch of data from the buffer and update the agent's network.
experience, unused_info = next(iterator)
train_loss = agent.train(experience)
iteration = agent.train_step_counter.numpy()
print ('iteration: {0} loss: {1}'.format(iteration, train_loss.loss))
ভিডিও জেনারেশন
def embed_gif(gif_buffer):
"""Embeds a gif file in the notebook."""
tag = '<img src="data:image/gif;base64,{0}"/>'.format(base64.b64encode(gif_buffer).decode())
return IPython.display.HTML(tag)
def run_episodes_and_create_video(policy, eval_tf_env, eval_py_env):
num_episodes = 3
frames = []
for _ in range(num_episodes):
time_step = eval_tf_env.reset()
frames.append(eval_py_env.render())
while not time_step.is_last():
action_step = policy.action(time_step)
time_step = eval_tf_env.step(action_step.action)
frames.append(eval_py_env.render())
gif_file = io.BytesIO()
imageio.mimsave(gif_file, frames, format='gif', fps=60)
IPython.display.display(embed_gif(gif_file.getvalue()))
একটি ভিডিও তৈরি করুন
একটি ভিডিও তৈরি করে নীতির কার্যকারিতা পরীক্ষা করুন৷
print ('global_step:')
print (global_step)
run_episodes_and_create_video(agent.policy, eval_env, eval_py_env)
global_step: <tf.Variable 'global_step:0' shape=() dtype=int64, numpy=0>
চেকপয়েন্টার এবং পলিসিসেভার সেটআপ করুন
এখন আমরা Checkpointer এবং PolicySaver ব্যবহার করার জন্য প্রস্তুত।
চেকপয়েন্টার
checkpoint_dir = os.path.join(tempdir, 'checkpoint')
train_checkpointer = common.Checkpointer(
ckpt_dir=checkpoint_dir,
max_to_keep=1,
agent=agent,
policy=agent.policy,
replay_buffer=replay_buffer,
global_step=global_step
)
পলিসি সেভার
policy_dir = os.path.join(tempdir, 'policy')
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
2022-01-20 12:15:14.054931: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
এক পুনরাবৃত্তি ট্রেন
print('Training one iteration....')
train_one_iteration()
Training one iteration.... WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version. Instructions for updating: back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.foldr(fn, elems, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems)) WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py:1096: calling foldr_v2 (from tensorflow.python.ops.functional_ops) with back_prop=False is deprecated and will be removed in a future version. Instructions for updating: back_prop=False is deprecated. Consider using tf.stop_gradient instead. Instead of: results = tf.foldr(fn, elems, back_prop=False) Use: results = tf.nest.map_structure(tf.stop_gradient, tf.foldr(fn, elems)) iteration: 1 loss: 1.0214563608169556
চেকপয়েন্টে সংরক্ষণ করুন
train_checkpointer.save(global_step)
চেকপয়েন্ট পুনরুদ্ধার করুন
এটি কাজ করার জন্য, চেকপয়েন্ট তৈরি করার সময় বস্তুর পুরো সেটটিকে একইভাবে পুনরায় তৈরি করা উচিত।
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
এছাড়াও নীতি সংরক্ষণ করুন এবং একটি অবস্থানে রপ্তানি করুন
tf_policy_saver.save(policy_dir)
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel. WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/policy/assets /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) INFO:tensorflow:Assets written to: /tmp/policy/assets
পলিসিটি তৈরি করার জন্য কোন এজেন্ট বা নেটওয়ার্ক ব্যবহার করা হয়েছে সে সম্পর্কে কোনো জ্ঞান ছাড়াই লোড করা যেতে পারে। এটি নীতির স্থাপনাকে অনেক সহজ করে তোলে।
সংরক্ষিত নীতি লোড করুন এবং এটি কীভাবে কাজ করে তা পরীক্ষা করুন
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
রপ্তানি এবং আমদানি
বাকি কোল্যাব আপনাকে চেকপয়েন্টার এবং নীতি নির্দেশিকাগুলি রপ্তানি/আমদানি করতে সাহায্য করবে যাতে আপনি পরবর্তী সময়ে প্রশিক্ষণ চালিয়ে যেতে পারেন এবং পুনরায় প্রশিক্ষণ না নিয়েই মডেলটি স্থাপন করতে পারেন।
এখন আপনি 'ট্রেন ওয়ান আইটারেশন'-এ ফিরে যেতে পারেন এবং আরও কয়েকবার প্রশিক্ষণ দিতে পারেন যাতে আপনি পরে পার্থক্য বুঝতে পারেন। একবার আপনি একটু ভালো ফলাফল দেখতে শুরু করলে, নিচে চালিয়ে যান।
জিপ ফাইল তৈরি করুন এবং জিপ ফাইল আপলোড করুন (কোডটি দেখতে ডাবল ক্লিক করুন)
def create_zip_file(dirname, base_filename):
return shutil.make_archive(base_filename, 'zip', dirname)
def upload_and_unzip_file_to(dirname):
if files is None:
return
uploaded = files.upload()
for fn in uploaded.keys():
print('User uploaded file "{name}" with length {length} bytes'.format(
name=fn, length=len(uploaded[fn])))
shutil.rmtree(dirname)
zip_files = zipfile.ZipFile(io.BytesIO(uploaded[fn]), 'r')
zip_files.extractall(dirname)
zip_files.close()
চেকপয়েন্ট ডিরেক্টরি থেকে একটি জিপ করা ফাইল তৈরি করুন।
train_checkpointer.save(global_step)
checkpoint_zip_filename = create_zip_file(checkpoint_dir, os.path.join(tempdir, 'exported_cp'))
জিপ ফাইলটি ডাউনলোড করুন।
if files is not None:
files.download(checkpoint_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
কিছু সময়ের জন্য প্রশিক্ষণের পর (10-15 বার), চেকপয়েন্ট জিপ ফাইলটি ডাউনলোড করুন এবং ট্রেনিং রিসেট করতে "Runtime > Restart and run all" এ যান এবং এই সেলে ফিরে আসুন। এখন আপনি ডাউনলোড করা জিপ ফাইল আপলোড করতে পারেন, এবং প্রশিক্ষণ চালিয়ে যেতে পারেন।
upload_and_unzip_file_to(checkpoint_dir)
train_checkpointer.initialize_or_restore()
global_step = tf.compat.v1.train.get_global_step()
একবার আপনি চেকপয়েন্ট ডিরেক্টরি আপলোড করলে, প্রশিক্ষণ চালিয়ে যেতে 'একটি পুনরাবৃত্তির ট্রেন'-এ ফিরে যান বা লোড করা নীতির কার্যকারিতা পরীক্ষা করতে 'একটি ভিডিও তৈরি করুন'-এ ফিরে যান।
বিকল্পভাবে, আপনি নীতি (মডেল) সংরক্ষণ করতে পারেন এবং এটি পুনরুদ্ধার করতে পারেন। চেকপয়েন্টারের বিপরীতে, আপনি প্রশিক্ষণ চালিয়ে যেতে পারবেন না, তবে আপনি এখনও মডেলটি স্থাপন করতে পারেন। উল্লেখ্য যে ডাউনলোড করা ফাইলটি চেকপয়েন্টারের তুলনায় অনেক ছোট।
tf_policy_saver.save(policy_dir)
policy_zip_filename = create_zip_file(policy_dir, os.path.join(tempdir, 'exported_policy'))
WARNING:absl:Function `function_with_signature` contains input name(s) 0/step_type, 0/reward, 0/discount, 0/observation with unsupported characters which will be renamed to step_type, reward, discount, observation in the SavedModel. WARNING:absl:Found untraced functions such as QNetwork_layer_call_fn, QNetwork_layer_call_and_return_conditional_losses, EncodingNetwork_layer_call_fn, EncodingNetwork_layer_call_and_return_conditional_losses, dense_1_layer_call_fn while saving (showing 5 of 25). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: /tmp/policy/assets /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/saved_model/nested_structure_coder.py:561: UserWarning: Encoding a StructuredValue with type tf_agents.policies.greedy_policy.DeterministicWithLogProb_ACTTypeSpec; loading this StructuredValue will require that this type be imported and registered. "imported and registered." % type_spec_class_name) INFO:tensorflow:Assets written to: /tmp/policy/assets
if files is not None:
files.download(policy_zip_filename) # try again if this fails: https://github.com/googlecolab/colabtools/issues/469
ডাউনলোড করা পলিসি ডাইরেক্টরি (exported_policy.zip) আপলোড করুন এবং সেভ করা পলিসি কীভাবে কাজ করে তা দেখুন।
upload_and_unzip_file_to(policy_dir)
saved_policy = tf.saved_model.load(policy_dir)
run_episodes_and_create_video(saved_policy, eval_env, eval_py_env)
SavedModelPyTFEagerPolicy
আপনি মেমরি নীতি ব্যবহার করতে না চান, তাহলে আপনার কাছে saved_model সরাসরি পাইথন env সঙ্গে ব্যবহারের মাধ্যমে ব্যবহার করতে পারেন py_tf_eager_policy.SavedModelPyTFEagerPolicy
।
মনে রাখবেন যে এটি শুধুমাত্র তখনই কাজ করে যখন ইজার মোড সক্ষম থাকে৷
eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
policy_dir, eval_py_env.time_step_spec(), eval_py_env.action_spec())
# Note that we're passing eval_py_env not eval_env.
run_episodes_and_create_video(eager_py_policy, eval_py_env, eval_py_env)
নীতিকে TFLite-এ রূপান্তর করুন
দেখুন TensorFlow লাইট রূপান্তরকারী আরো বিস্তারিত জানার জন্য।
converter = tf.lite.TFLiteConverter.from_saved_model(policy_dir, signature_keys=["action"])
tflite_policy = converter.convert()
with open(os.path.join(tempdir, 'policy.tflite'), 'wb') as f:
f.write(tflite_policy)
2022-01-20 12:15:59.646042: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:363] Ignored output_format. 2022-01-20 12:15:59.646082: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:366] Ignored drop_control_dependency. 2022-01-20 12:15:59.646088: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:372] Ignored change_concat_input_ranges. WARNING:absl:Buffer deduplication procedure will be skipped when flatbuffer library is not properly loaded
TFLite মডেলে অনুমান চালান
দেখুন TensorFlow লাইট ইনফিরেনস আরো বিস্তারিত জানার জন্য।
import numpy as np
interpreter = tf.lite.Interpreter(os.path.join(tempdir, 'policy.tflite'))
policy_runner = interpreter.get_signature_runner()
print(policy_runner._inputs)
{'0/discount': 1, '0/observation': 2, '0/reward': 3, '0/step_type': 0}
policy_runner(**{
'0/discount':tf.constant(0.0),
'0/observation':tf.zeros([1,4]),
'0/reward':tf.constant(0.0),
'0/step_type':tf.constant(0)})
{'action': array([0])}