Bản quyền 2021 Các tác giả TF-Agents.
Xem trên TensorFlow.org | Chạy trong Google Colab | Xem nguồn trên GitHub | Tải xuống sổ ghi chép |
Giới thiệu
Mô hình phổ biến trong học củng cố là thực thi một chính sách trong một môi trường cho một số bước hoặc tập cụ thể. Điều này xảy ra, ví dụ, trong quá trình thu thập dữ liệu, đánh giá và tạo video về đại lý.
Trong khi điều này là tương đối đơn giản để ghi trong python, nó là nhiều hơn nữa phức tạp để viết và gỡ lỗi trong TensorFlow vì nó liên quan đến tf.while
vòng, tf.cond
và tf.control_dependencies
. Vì vậy chúng tôi tóm tắt khái niệm này của một vòng lặp chạy vào một lớp được gọi là driver
, và cung cấp triển khai cũng được thử nghiệm cả về Python và TensorFlow.
Ngoài ra, dữ liệu mà trình điều khiển gặp phải ở mỗi bước được lưu trong một bộ dữ liệu được đặt tên có tên là Trajectory và phát tới một tập hợp các quan sát viên như bộ đệm phát lại và số liệu. Dữ liệu này bao gồm quan sát từ môi trường, hành động được đề xuất bởi chính sách, phần thưởng nhận được, loại hiện tại và bước tiếp theo, v.v.
Thành lập
Nếu bạn chưa cài đặt tf-agent hoặc phòng tập thể dục, hãy chạy:
pip install tf-agents
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_py_policy
from tf_agents.policies import random_tf_policy
from tf_agents.metrics import py_metrics
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import py_driver
from tf_agents.drivers import dynamic_episode_driver
Trình điều khiển Python
Các PyDriver
lớp có một môi trường trăn, một chính sách python và một danh sách các nhà quan sát để cập nhật tại mỗi bước. Phương pháp chính được run()
, mà bước môi trường sử dụng hành động từ chính sách cho đến khi ít nhất một trong những tiêu chí chấm dứt sau được đáp ứng: Số lượng các bước đạt max_steps
hoặc số lượng tập phim đạt max_episodes
.
Cách thực hiện đại khái như sau:
class PyDriver(object):
def __init__(self, env, policy, observers, max_steps=1, max_episodes=1):
self._env = env
self._policy = policy
self._observers = observers or []
self._max_steps = max_steps or np.inf
self._max_episodes = max_episodes or np.inf
def run(self, time_step, policy_state=()):
num_steps = 0
num_episodes = 0
while num_steps < self._max_steps and num_episodes < self._max_episodes:
# Compute an action using the policy for the given time_step
action_step = self._policy.action(time_step, policy_state)
# Apply the action to the environment and get the next step
next_time_step = self._env.step(action_step.action)
# Package information into a trajectory
traj = trajectory.Trajectory(
time_step.step_type,
time_step.observation,
action_step.action,
action_step.info,
next_time_step.step_type,
next_time_step.reward,
next_time_step.discount)
for observer in self._observers:
observer(traj)
# Update statistics to check termination
num_episodes += np.sum(traj.is_last())
num_steps += np.sum(~traj.is_boundary())
time_step = next_time_step
policy_state = action_step.state
return time_step, policy_state
Bây giờ, chúng ta hãy xem qua ví dụ về việc chạy một chính sách ngẫu nhiên trên môi trường CartPole, lưu kết quả vào bộ đệm phát lại và tính toán một số chỉ số.
env = suite_gym.load('CartPole-v0')
policy = random_py_policy.RandomPyPolicy(time_step_spec=env.time_step_spec(),
action_spec=env.action_spec())
replay_buffer = []
metric = py_metrics.AverageReturnMetric()
observers = [replay_buffer.append, metric]
driver = py_driver.PyDriver(
env, policy, observers, max_steps=20, max_episodes=1)
initial_time_step = env.reset()
final_time_step, _ = driver.run(initial_time_step)
print('Replay Buffer:')
for traj in replay_buffer:
print(traj)
print('Average Return: ', metric.result())
Replay Buffer: Trajectory( {'action': array(1), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([-0.01483762, -0.0301547 , -0.02482025, 0.00477367], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(0, dtype=int32)}) Trajectory( {'action': array(1), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([-0.01544072, 0.16531426, -0.02472478, -0.29563585], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(1), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([-0.01213443, 0.3607798 , -0.0306375 , -0.5960129 ], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(1), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([-0.00491884, 0.5563168 , -0.04255775, -0.8981868 ], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(0), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([ 0.0062075 , 0.75198895, -0.06052149, -1.2039375 ], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(0), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([ 0.02124728, 0.5576993 , -0.08460024, -0.9308191 ], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(0), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([ 0.03240127, 0.36381477, -0.10321662, -0.6658752 ], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(1), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([ 0.03967756, 0.17026839, -0.11653412, -0.40739253], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(0), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([ 0.04308293, 0.36683324, -0.12468197, -0.7344236 ], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(0), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([ 0.0504196 , 0.17363413, -0.13937044, -0.48343614], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(1), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([ 0.05389228, -0.0192741 , -0.14903916, -0.23772195], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(1), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([ 0.0535068 , 0.17762792, -0.1537936 , -0.5734562 ], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(0), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([ 0.05705936, 0.37453365, -0.16526273, -0.910366 ], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(0), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([ 0.06455003, 0.18198717, -0.18347006, -0.6738478 ], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(0), 'discount': array(1., dtype=float32), 'next_step_type': array(1, dtype=int32), 'observation': array([ 0.06818977, -0.01017502, -0.19694701, -0.44408032], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(0), 'discount': array(0., dtype=float32), 'next_step_type': array(2, dtype=int32), 'observation': array([ 0.06798627, -0.20204504, -0.20582862, -0.21936782], dtype=float32), 'policy_info': (), 'reward': array(1., dtype=float32), 'step_type': array(1, dtype=int32)}) Trajectory( {'action': array(1), 'discount': array(1., dtype=float32), 'next_step_type': array(0, dtype=int32), 'observation': array([ 0.06394537, -0.39372152, -0.21021597, 0.00199082], dtype=float32), 'policy_info': (), 'reward': array(0., dtype=float32), 'step_type': array(2, dtype=int32)}) Average Return: 16.0
Trình điều khiển TensorFlow
Chúng tôi cũng có trình điều khiển trong TensorFlow mà có chức năng tương tự như trình điều khiển Python, nhưng môi trường sử dụng TF, chính sách TF, quan sát TF vv Chúng tôi hiện đang có tài xế 2 TensorFlow: DynamicStepDriver
, mà chấm dứt sau một số lần nhất định bước (có giá trị) và môi trường DynamicEpisodeDriver
, kết thúc sau một số tập nhất định. Hãy để chúng tôi xem xét một ví dụ về DynamicEp Chap đang hoạt động.
env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)
tf_policy = random_tf_policy.RandomTFPolicy(action_spec=tf_env.action_spec(),
time_step_spec=tf_env.time_step_spec())
num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()
observers = [num_episodes, env_steps]
driver = dynamic_episode_driver.DynamicEpisodeDriver(
tf_env, tf_policy, observers, num_episodes=2)
# Initial driver.run will reset the environment and initialize the policy.
final_time_step, policy_state = driver.run()
print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep( {'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>, 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=array([[0.01182632, 0.01372784, 0.03056967, 0.04454206]], dtype=float32)>, 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>, 'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>}) Number of Steps: 24 Number of Episodes: 2
# Continue running from previous state
final_time_step, _ = driver.run(final_time_step, policy_state)
print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep( {'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>, 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy= array([[-0.02565088, 0.04813434, -0.04199163, 0.03810809]], dtype=float32)>, 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>, 'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>}) Number of Steps: 70 Number of Episodes: 4