A driver that runs a Python policy in a Python environment.
Inherits From: Driver
tf_agents.drivers.py_driver.PyDriver(
    env: tf_agents.environments.PyEnvironment,
    policy: tf_agents.policies.py_policy.PyPolicy,
    observers: Sequence[Callable[[trajectory.Trajectory], Any]],
    transition_observers: Optional[Sequence[Callable[[trajectory.Transition], Any]]] = None,
    info_observers: Optional[Sequence[Callable[[Any], Any]]] = None,
    max_steps: Optional[types.Int] = None,
    max_episodes: Optional[types.Int] = None,
    end_episode_on_boundary: bool = True
)
Args

env
    A py_environment.Base environment.

policy
    A py_policy.PyPolicy policy.

observers
    A list of observers that are notified after every step in the
    environment. Each observer is a callable(trajectory.Trajectory).

transition_observers
    A list of observers that are notified after every step in the
    environment. Each observer is a callable((TimeStep, PolicyStep,
    NextTimeStep)). The transition is shaped just as trajectories are
    for regular observers.

info_observers
    A list of observers that are notified after every step in the
    environment. Each observer is a callable(info).

max_steps
    Optional maximum number of steps for each run() call. For batched
    or parallel environments, this is the maximum total number of
    steps summed across all environments. See also max_episodes.
    Default: None.

max_episodes
    Optional maximum number of episodes for each run() call. For
    batched or parallel environments, this is the maximum total number
    of episodes summed across all environments. At least one of
    max_steps or max_episodes must be provided. If both are set,
    run() terminates when at least one of the conditions is
    satisfied. Default: None.

end_episode_on_boundary
    Set this to True (the default) when using trajectory observers
    and to False when using transition observers.
Raises

ValueError
    If both max_steps and max_episodes are None.
Attributes

env
info_observers
observers
policy
transition_observers
Methods
run
run(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedArray = ()
) -> Tuple[tf_agents.trajectories.TimeStep, tf_agents.typing.types.NestedArray]
Run policy in environment given initial time_step and policy_state.
Args

time_step
    The initial time_step.

policy_state
    The initial policy_state.

Returns

A tuple (final time_step, final policy_state).