Copyright 2018 The TF-Agents Authors.
TensorFlow.orgで表示 | Google Colabで実行 | GitHub でソースを表示{ | ノートブックをダウンロード/a0} |
はじめに
強化学習 (RL) の目標は、環境と対話することにより学習するエージェントを設計することです。標準的な RL のセットアップでは、エージェントは各タイムステップで観測を受け取り、行動を選択します。行動は環境に適用され、環境は報酬と新しい観察を返します。 エージェントは、報酬の合計 (リターン) を最大化する行動を選択するポリシーをトレーニングします。
TF-Agent では、環境は Python または TensorFlow で実装できます。通常、Python 環境はより分かりやすく、実装やデバッグが簡単ですが、TensorFlow 環境はより効率的で自然な並列化が可能です。最も一般的なワークフローは、Python で環境を実装し、ラッパーを使用して自動的に TensorFlow に変換することです。
最初に Python 環境を見てみましょう。TensorFlow 環境の API もよく似ています。
セットアップ
TF-Agent または gym をまだインストールしていない場合は、以下を実行します。
pip install -q tf-agents
pip install -q 'gym==0.10.11'
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import tensorflow as tf
import numpy as np
from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts
tf.compat.v1.enable_v2_behavior()
Python 環境
Python環境には、環境に行動を適用し、次のステップに関する以下の情報を返す step(action) -> next_time_step
メソッドがあります。
observation
:これは、エージェントが次のステップで行動を選択するために観察できる環境状態の一部です。reward
:エージェントは、複数のステップにわたってこれらの報酬の合計を最大化することを学習します。step_type
:環境との相互作用は通常、シーケンス/エピソードの一部です (チェスのゲームで複数の動きがあるように)。step_type は、FIRST
、MID
またはLAST
のいずれかで、このタイムステップがシーケンスの最初、中間、または最後のステップかどうかを示します。discount
:これは、現在のタイムステップでの報酬に対する次のタイムステップでの報酬の重み付けを表す浮動小数です。
これらは、名前付きタプルTimeStep(step_type, reward, discount, observation)
にグループ化されます。
すべての Python 環境で実装する必要があるインターフェースは、environments/py_environment.PyEnvironment
です。主なメソッドは、以下のとおりです。
class PyEnvironment(object):
def reset(self):
"""Return initial_time_step."""
self._current_time_step = self._reset()
return self._current_time_step
def step(self, action):
"""Apply action and return new time_step."""
if self._current_time_step is None:
return self.reset()
self._current_time_step = self._step(action)
return self._current_time_step
def current_time_step(self):
return self._current_time_step
def time_step_spec(self):
"""Return time_step_spec."""
@abc.abstractmethod
def observation_spec(self):
"""Return observation_spec."""
@abc.abstractmethod
def action_spec(self):
"""Return action_spec."""
@abc.abstractmethod
def _reset(self):
"""Return initial_time_step."""
@abc.abstractmethod
def _step(self, action):
"""Apply action and return new time_step."""
self._current_time_step = self._step(action)
return self._current_time_step
step()
メソッドに加えて、環境では、新しいシーケンスを開始して新規TimeStep
を提供するreset()
メソッドも提供されます。reset
メソッドを明示的に呼び出す必要はありません。エピソードの最後、またはstep()が初めて呼び出されたときに、環境は自動的にリセットされると想定されています。
サブクラスはstep()
またはreset()
を直接実装しないことに注意してください。代わりに、_step()
および_reset()
メソッドをオーバーライドします。これらのメソッドから返されたタイムステップはキャッシュされ、current_time_step()
を通じて公開されます。
observation_spec
およびaction_spec
メソッドは(Bounded)ArraySpecs
のネストを返します。このネストは観測と行動の名前、形状、データ型、範囲をそれぞれ記述します。
TF-Agent では、リスト、タプル、名前付きタプル、またはディクショナリからなるツリー構造で定義されるネストを繰り返し参照します。これらは、観察と行動の構造を維持するために任意に構成できます。これは、多くの観察と行動がある、より複雑な環境で非常に役立ちます。
標準環境の使用
TF Agent には、py_environment.PyEnvironment
インターフェースに準拠するように、OpenAI Gym、DeepMind-control、Atari などの多くの標準環境用のラッパーが組み込まれていています。これらのラップされた環境は、環境スイートを使用して簡単に読み込めます。OpenAI Gym から CartPole 環境を読み込み、行動と time_step_spec を見てみましょう。
environment = suite_gym.load('CartPole-v0')
print('action_spec:', environment.action_spec())
print('time_step_spec.observation:', environment.time_step_spec().observation)
print('time_step_spec.step_type:', environment.time_step_spec().step_type)
print('time_step_spec.discount:', environment.time_step_spec().discount)
print('time_step_spec.reward:', environment.time_step_spec().reward)
action_spec: BoundedArraySpec(shape=(), dtype=dtype('int64'), name='action', minimum=0, maximum=1) time_step_spec.observation: BoundedArraySpec(shape=(4,), dtype=dtype('float32'), name='observation', minimum=[-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], maximum=[4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38]) time_step_spec.step_type: ArraySpec(shape=(), dtype=dtype('int32'), name='step_type') time_step_spec.discount: BoundedArraySpec(shape=(), dtype=dtype('float32'), name='discount', minimum=0.0, maximum=1.0) time_step_spec.reward: ArraySpec(shape=(), dtype=dtype('float32'), name='reward')
環境は [0, 1] のint64
タイプの行動を予期し、 TimeSteps
を返します。観測値は長さ 4 のfloat32
ベクトルであり、割引係数は [0.0, 1.0] のfloat32
です。では、エピソード全体に対して固定した行動(1,)
を実行してみましょう。
action = np.array(1, dtype=np.int32)
time_step = environment.reset()
print(time_step)
while not time_step.is_last():
time_step = environment.step(action)
print(time_step)
TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.02262211, -0.01128484, -0.01841794, 0.00293427], dtype=float32)) TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.02239642, 0.18409634, -0.01835925, -0.29550236], dtype=float32)) TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.02607834, 0.37947515, -0.0242693 , -0.5939185 ], dtype=float32)) TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.03366784, 0.5749282 , -0.03614767, -0.89414626], dtype=float32)) TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.04516641, 0.7705213 , -0.05403059, -1.1979693 ], dtype=float32)) TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.06057683, 0.96629924, -0.07798998, -1.5070848 ], dtype=float32)) TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.07990282, 1.1622754 , -0.10813168, -1.8230615 ], dtype=float32)) TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.10314833, 1.358419 , -0.14459291, -2.1472874 ], dtype=float32)) TimeStep(step_type=array(1, dtype=int32), reward=array(1., dtype=float32), discount=array(1., dtype=float32), observation=array([ 0.1303167 , 1.5546396 , -0.18753865, -2.480909 ], dtype=float32)) TimeStep(step_type=array(2, dtype=int32), reward=array(1., dtype=float32), discount=array(0., dtype=float32), observation=array([ 0.1614095 , 1.7507701 , -0.23715684, -2.8247602 ], dtype=float32))
独自 Python 環境の作成
多くの場合、一般的に、TF-Agent の標準エージェント (agents/を参照) の 1 つが問題に適用されます。そのためには、問題を環境としてまとめる必要があります。 次に、Python で環境を実装する方法を見てみましょう。
次の (ブラックジャックのような ) カードゲームをプレイするようにエージェントをトレーニングするとします。
- ゲームは、1~10 の数値が付けられた無限のカード一式を使用してプレイします。
- 毎回、エージェントは2つの行動 (新しいランダムカードを取得する、またはその時点のラウンドを停止する) を実行できます。
- 目標はラウンド終了時にカードの合計を 21 にできるだけ近づけることです。
ゲームを表す環境は次のようになります。
- 行動:2 つの行動があります( 行動 0:新しいカードを取得、行動1:その時点のラウンドを終了)。
- 観察:その時点のラウンドのカードの合計。
- 報酬:目標は、21 にできるだけ近づけることなので、ラウンド終了時に次の報酬を使用します。sum_of_cards - 21 if sum_of_cards <= 21, else -21
class CardGameEnv(py_environment.PyEnvironment):
def __init__(self):
self._action_spec = array_spec.BoundedArraySpec(
shape=(), dtype=np.int32, minimum=0, maximum=1, name='action')
self._observation_spec = array_spec.BoundedArraySpec(
shape=(1,), dtype=np.int32, minimum=0, name='observation')
self._state = 0
self._episode_ended = False
def action_spec(self):
return self._action_spec
def observation_spec(self):
return self._observation_spec
def _reset(self):
self._state = 0
self._episode_ended = False
return ts.restart(np.array([self._state], dtype=np.int32))
def _step(self, action):
if self._episode_ended:
# The last action ended the episode. Ignore the current action and start
# a new episode.
return self.reset()
# Make sure episodes don't go on forever.
if action == 1:
self._episode_ended = True
elif action == 0:
new_card = np.random.randint(1, 11)
self._state += new_card
else:
raise ValueError('`action` should be 0 or 1.')
if self._episode_ended or self._state >= 21:
reward = self._state - 21 if self._state <= 21 else -21
return ts.termination(np.array([self._state], dtype=np.int32), reward)
else:
return ts.transition(
np.array([self._state], dtype=np.int32), reward=0.0, discount=1.0)
上記の環境がすべて正しく定義されていることを確認しましょう。独自の環境を作成する場合、生成された観測と time_steps が仕様で定義されている正しい形状とタイプに従っていることを確認する必要があります。これらは TensorFlow グラフの生成に使用されるため、問題が発生するとデバッグが困難になる可能性があります。
この環境を検証するために、ランダムなポリシーを使用して行動を生成し、5 つのエピソードでイテレーションを実行し、意図したとおりに機能していることを確認します。環境の仕様に従っていない time_step を受け取ると、エラーが発生します。
environment = CardGameEnv()
utils.validate_py_environment(environment, episodes=5)
環境が意図するとおりに機能していることが確認できたので、固定ポリシーを使用してこの環境を実行してみましょう。3 枚のカードを要求して、ラウンドを終了します。
get_new_card_action = np.array(0, dtype=np.int32)
end_round_action = np.array(1, dtype=np.int32)
environment = CardGameEnv()
time_step = environment.reset()
print(time_step)
cumulative_reward = time_step.reward
for _ in range(3):
time_step = environment.step(get_new_card_action)
print(time_step)
cumulative_reward += time_step.reward
time_step = environment.step(end_round_action)
print(time_step)
cumulative_reward += time_step.reward
print('Final Reward = ', cumulative_reward)
TimeStep(step_type=array(0, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([0], dtype=int32)) TimeStep(step_type=array(1, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([6], dtype=int32)) TimeStep(step_type=array(1, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([11], dtype=int32)) TimeStep(step_type=array(1, dtype=int32), reward=array(0., dtype=float32), discount=array(1., dtype=float32), observation=array([19], dtype=int32)) TimeStep(step_type=array(2, dtype=int32), reward=array(-2., dtype=float32), discount=array(0., dtype=float32), observation=array([19], dtype=int32)) Final Reward = -2.0
環境ラッパー
環境ラッパーは Python 環境を取り、環境の変更されたバージョンを返します。元の環境と変更された環境はどちらもpy_environment.PyEnvironment
のインスタンスであり、複数のラッパーをチェーン化することもできます。
一般的なラッパーはenvironments/wrappers.py
にあります。 例:
ActionDiscretizeWrapper
:連続空間で定義された行動を離散化された行動に変換します。RunStats
: 実行したステップ数、完了したエピソード数など、環境の実行統計をキャプチャします。TimeLimit
:一定のステップ数の後にエピソードを終了します。
例1:行動離散化ラッパー
InvertedPendulumは、[-2, 2]
の範囲の連続行動を受け入れる PyBullet 環境です。この環境で DQN などの離散行動エージェントをトレーニングする場合は、行動空間を離散化(量子化)する必要があります。ActionDiscretizeWrapper
は、これを行います。ラップ前とラップ後のaction_spec
を比較しましょう。
env = suite_gym.load('Pendulum-v0')
print('Action Spec:', env.action_spec())
discrete_action_env = wrappers.ActionDiscretizeWrapper(env, num_actions=5)
print('Discretized Action Spec:', discrete_action_env.action_spec())
Action Spec: BoundedArraySpec(shape=(1,), dtype=dtype('float32'), name='action', minimum=-2.0, maximum=2.0) Discretized Action Spec: BoundedArraySpec(shape=(), dtype=dtype('int32'), name='action', minimum=0, maximum=4)
ラップされたdiscrete_action_env
は、py_environment.PyEnvironment
のインスタンスで、通常の python 環境のように扱うことができます。
TensorFlow 環境
TF 環境のインターフェースはenvironments/tf_environment.TFEnvironment
で定義されており、Python 環境とよく似ています。ただし、TF 環境は以下の点で Python 環境と異なります。
- 配列の代わりにテンソルオブジェクトを生成する
- TF 環境は、仕様と比較したときに生成されたテンソルにバッチディメンションを追加します
Python環境をTF環境に変換すると、TensorFlowで操作を並列化できます。たとえば、環境からデータを収集してreplay_buffer
に追加するcollect_experience_op
、および、replay_buffer
から読み取り、エージェントをトレーニングするtrain_op
を定義し、TensorFlowで自然に並列実行することができます。
class TFEnvironment(object):
def time_step_spec(self):
"""Describes the `TimeStep` tensors returned by `step()`."""
def observation_spec(self):
"""Defines the `TensorSpec` of observations provided by the environment."""
def action_spec(self):
"""Describes the TensorSpecs of the action expected by `step(action)`."""
def reset(self):
"""Returns the current `TimeStep` after resetting the Environment."""
return self._reset()
def current_time_step(self):
"""Returns the current `TimeStep`."""
return self._current_time_step()
def step(self, action):
"""Applies the action and returns the new `TimeStep`."""
return self._step(action)
@abc.abstractmethod
def _reset(self):
"""Returns the current `TimeStep` after resetting the Environment."""
@abc.abstractmethod
def _current_time_step(self):
"""Returns the current `TimeStep`."""
@abc.abstractmethod
def _step(self, action):
"""Applies the action and returns the new `TimeStep`."""
current_time_step()
メソッドは現在の time_step を返し、必要に応じて環境を初期化します。
reset()
メソッドは環境を強制的にリセットし、current_step を返します。
action
が以前のtime_step
に依存しない場合、Graph
モードではtf.control_dependency
が必要です。
ここでは、TFEnvironments
を作成する方法を見ていきます。
独自 TensorFlow 環境の作成
これは Python で環境を作成するよりも複雑であるため、この Colab では取り上げません。例はこちらからご覧いただけます。より一般的な使用例は、Python で環境を実装し、TFPyEnvironment
ラッパーを使用して TensorFlow でラップすることです (以下を参照)。
TensorFlow で Python 環境をラップ
TFPyEnvironment
ラッパーを使用すると、任意の Python 環境を TensorFlow 環境に簡単にラップできます。
env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)
print(isinstance(tf_env, tf_environment.TFEnvironment))
print("TimeStep Specs:", tf_env.time_step_spec())
print("Action Specs:", tf_env.action_spec())
True TimeStep Specs: TimeStep(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)), observation=BoundedTensorSpec(shape=(4,), dtype=tf.float32, name='observation', minimum=array([-4.8000002e+00, -3.4028235e+38, -4.1887903e-01, -3.4028235e+38], dtype=float32), maximum=array([4.8000002e+00, 3.4028235e+38, 4.1887903e-01, 3.4028235e+38], dtype=float32))) Action Specs: BoundedTensorSpec(shape=(), dtype=tf.int64, name='action', minimum=array(0), maximum=array(1))
仕様のタイプが(Bounded)TensorSpec
になっていることに注意してください。
使用例
簡単な例
env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)
# reset() creates the initial time_step after resetting the environment.
time_step = tf_env.reset()
num_steps = 3
transitions = []
reward = 0
for i in range(num_steps):
action = tf.constant([i % 2])
# applies the action and returns the new TimeStep.
next_time_step = tf_env.step(action)
transitions.append([time_step, action, next_time_step])
reward += next_time_step.reward
time_step = next_time_step
np_transitions = tf.nest.map_structure(lambda x: x.numpy(), transitions)
print('\n'.join(map(str, np_transitions)))
print('Total reward:', reward.numpy())
[TimeStep(step_type=array([0], dtype=int32), reward=array([0.], dtype=float32), discount=array([1.], dtype=float32), observation=array([[ 0.02205312, -0.02536285, 0.03750031, -0.03571422]], dtype=float32)), array([0], dtype=int32), TimeStep(step_type=array([1], dtype=int32), reward=array([1.], dtype=float32), discount=array([1.], dtype=float32), observation=array([[ 0.02154586, -0.22100194, 0.03678603, 0.2685606 ]], dtype=float32))] [TimeStep(step_type=array([1], dtype=int32), reward=array([1.], dtype=float32), discount=array([1.], dtype=float32), observation=array([[ 0.02154586, -0.22100194, 0.03678603, 0.2685606 ]], dtype=float32)), array([1], dtype=int32), TimeStep(step_type=array([1], dtype=int32), reward=array([1.], dtype=float32), discount=array([1.], dtype=float32), observation=array([[ 0.01712583, -0.02642375, 0.04215724, -0.01229658]], dtype=float32))] [TimeStep(step_type=array([1], dtype=int32), reward=array([1.], dtype=float32), discount=array([1.], dtype=float32), observation=array([[ 0.01712583, -0.02642375, 0.04215724, -0.01229658]], dtype=float32)), array([0], dtype=int32), TimeStep(step_type=array([1], dtype=int32), reward=array([1.], dtype=float32), discount=array([1.], dtype=float32), observation=array([[ 0.01659735, -0.22212414, 0.04191131, 0.29338375]], dtype=float32))] Total reward: [3.]
全エピソード
env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)
time_step = tf_env.reset()
rewards = []
steps = []
num_episodes = 5
for _ in range(num_episodes):
episode_reward = 0
episode_steps = 0
while not time_step.is_last():
action = tf.random.uniform([1], 0, 2, dtype=tf.int32)
time_step = tf_env.step(action)
episode_steps += 1
episode_reward += time_step.reward.numpy()
rewards.append(episode_reward)
steps.append(episode_steps)
time_step = tf_env.reset()
num_steps = np.sum(steps)
avg_length = np.mean(steps)
avg_reward = np.mean(rewards)
print('num_episodes:', num_episodes, 'num_steps:', num_steps)
print('avg_length', avg_length, 'avg_reward:', avg_reward)
num_episodes: 5 num_steps: 87 avg_length 17.4 avg_reward: 17.4