Abstract base class for TF Policies.
tf_agents.policies.TFPolicy(
    time_step_spec: tf_agents.trajectories.TimeStep,
    action_spec: tf_agents.typing.types.NestedTensorSpec,
    policy_state_spec: tf_agents.typing.types.NestedTensorSpec = (),
    info_spec: tf_agents.typing.types.NestedTensorSpec = (),
    clip: bool = True,
    emit_log_probability: bool = False,
    automatic_state_reset: bool = True,
    observation_and_action_constraint_splitter: Optional[types.Splitter] = None,
    validate_args: bool = True,
    name: Optional[Text] = None
)
The Policy represents a mapping from `time_steps` received from the environment to actions that can be applied to the environment.

Agents expose two policies: a `policy` meant for deployment and evaluation, and a `collect_policy` for collecting data from the environment. The `collect_policy` is usually stochastic, so that it explores the environment better, and it may also log auxiliary information required for training, such as log probabilities. `Policy` objects can also be created directly by users without using an `Agent`.
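To make the split between the two policies concrete, here is a minimal pure-Python sketch (no TF-Agents dependency) of a deterministic evaluation policy and a stochastic collect policy over a fixed set of per-action values. The names `PolicyStep`, `GreedyPolicy`, and `EpsilonGreedyCollectPolicy`, and the choice of epsilon-greedy exploration, are illustrative assumptions rather than TF-Agents APIs. The collect policy records the log probability of its chosen action in `info`, mirroring the kind of auxiliary information described above.

```python
import math
import random
from collections import namedtuple

# Hypothetical stand-in for tf_agents.trajectories.PolicyStep.
PolicyStep = namedtuple("PolicyStep", ["action", "state", "info"])


class GreedyPolicy:
  """Deterministic policy for deployment/evaluation: always picks the best action."""

  def __init__(self, q_values):
    self.q_values = list(q_values)  # Hypothetical per-action value estimates.

  def action(self, observation, policy_state=()):
    best = max(range(len(self.q_values)), key=lambda a: self.q_values[a])
    return PolicyStep(action=best, state=policy_state, info=())


class EpsilonGreedyCollectPolicy(GreedyPolicy):
  """Stochastic policy for data collection: explores with probability epsilon."""

  def __init__(self, q_values, epsilon=0.1, seed=None):
    super().__init__(q_values)
    self.epsilon = epsilon
    self.rng = random.Random(seed)

  def _probs(self):
    n = len(self.q_values)
    greedy = super().action(None).action
    # Uniform exploration mass on every action, remaining mass on the greedy one.
    return [self.epsilon / n + (1.0 - self.epsilon) * (a == greedy)
            for a in range(n)]

  def action(self, observation, policy_state=()):
    probs = self._probs()
    chosen = self.rng.choices(range(len(probs)), weights=probs)[0]
    # The log probability is side information some training algorithms need.
    return PolicyStep(action=chosen, state=policy_state,
                      info={"log_probability": math.log(probs[chosen])})
```

Both objects expose the same `action` method, so evaluation and collection code can be written once and handed either policy.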
The main methods of TFPolicy are:

- `action`: Maps a `time_step` from the environment to an action.
- `distribution`: Maps a `time_step` to a distribution over actions.
- `get_initial_state`: Generates the initial state for stateful policies, e.g. RNN/LSTM policies.
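The relationship among these three methods, and how state is threaded from one call to the next, can be illustrated with a self-contained toy policy in plain Python. `AlternatingPolicy` and its list-of-probabilities "distribution" are hypothetical stand-ins, not TF-Agents classes; in TF-Agents the corresponding values would be tensors and distribution objects.

```python
from collections import namedtuple

# Hypothetical stand-in for tf_agents.trajectories.PolicyStep.
PolicyStep = namedtuple("PolicyStep", ["action", "state", "info"])


class AlternatingPolicy:
  """Toy stateful policy: the action is the step count's parity, per batch entry."""

  num_actions = 2

  def get_initial_state(self, batch_size):
    # One integer counter per batch entry, all starting at zero.
    return [0] * batch_size

  def distribution(self, time_step, policy_state):
    # One-hot "distribution" per batch entry: all mass on count % 2.
    probs = [[1.0 if a == count % 2 else 0.0 for a in range(self.num_actions)]
             for count in policy_state]
    next_state = [count + 1 for count in policy_state]
    return PolicyStep(action=probs, state=next_state, info=())

  def action(self, time_step, policy_state):
    # Reduce each per-entry distribution to a concrete action (its mode).
    dist_step = self.distribution(time_step, policy_state)
    actions = [max(range(len(p)), key=p.__getitem__) for p in dist_step.action]
    return PolicyStep(action=actions, state=dist_step.state, info=())
```

As in the example below, the caller feeds `policy_step.state` back into the next call; with a batch of two, three calls produce the actions `[0, 0]`, `[1, 1]`, `[0, 0]`.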
Example usage:
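from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.policies import random_tf_policy

# CartPole is only an illustrative choice; any TFEnvironment works here.
env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))
policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(), env.action_spec())
# Or policy = agent.policy or agent.collect_policy
policy_state = policy.get_initial_state(env.batch_size)
time_step = env.reset()
while not time_step.is_last():
  policy_step = policy.action(time_step, policy_state)
  time_step = env.step(policy_step.action)
  policy_state = policy_step.state
  # policy_step.info may contain side info for logging, such as action log
  # probabilities.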
Policies can be saved to disk as SavedModels (see policy_saver.py and policy_loader.py) or as TF Checkpoints.
A `PyTFEagerPolicy` can be used to wrap a `TFPolicy` so that it works with `PyEnvironment`s.
For researchers, and those developing new Policies, the `TFPolicy` base class constructor also accepts a `validate_args` parameter. If `False`, this disables all spec structure, dtype, and shape checks in the public methods of these classes. This allows algorithm developers to iterate and try different input and output structures without worrying about overly restrictive requirements, or about input and output states being in a certain format. However, disabling argument validation can make it very hard to identify structural input or algorithmic errors, and should not be done for final, or production-ready, Policies. In addition to having implementations that may disagree with specs, the resulting Policy may no longer interact well with other parts of TF-Agents. Examples include impedance mismatches with Actor/Learner APIs, replay buffers, and the model export functionality in `PolicySaver`.
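The kind of check that `validate_args` toggles can be sketched in plain Python. Here a spec is a nested dict whose leaves are a hypothetical `Spec(shape, dtype)`, and values are nested Python lists; `Spec`, `check_against_spec`, and the toy `action` function are illustrative assumptions, not the TF-Agents implementation, which validates against nests of `TensorSpec`s.

```python
from collections import namedtuple

# Hypothetical leaf spec: a shape tuple and a Python element type.
Spec = namedtuple("Spec", ["shape", "dtype"])


def _shape(value):
  """Shape of a rectangular nested list; scalars have shape ()."""
  if isinstance(value, list):
    return (len(value),) + (_shape(value[0]) if value else ())
  return ()


def _leaves(value):
  """Yield every scalar element of a nested list."""
  if isinstance(value, list):
    for v in value:
      yield from _leaves(v)
  else:
    yield value


def check_against_spec(value, spec, path="root"):
  """Raise TypeError/ValueError if `value` does not match the nested `spec`."""
  if isinstance(spec, Spec):
    if _shape(value) != spec.shape:
      raise ValueError(f"{path}: shape {_shape(value)} != {spec.shape}")
    if not all(isinstance(x, spec.dtype) for x in _leaves(value)):
      raise TypeError(f"{path}: expected elements of type {spec.dtype.__name__}")
    return
  if isinstance(spec, dict):
    if not isinstance(value, dict) or set(value) != set(spec):
      raise TypeError(f"{path}: structure mismatch, expected keys {sorted(spec)}")
    for key in spec:
      check_against_spec(value[key], spec[key], f"{path}.{key}")
    return
  raise TypeError(f"{path}: unsupported spec {spec!r}")


def action(observation, observation_spec, validate_args=True):
  """Toy policy method: validates its input only when validate_args is True."""
  if validate_args:
    check_against_spec(observation, observation_spec)
  return 0  # Placeholder action.
```

With `validate_args=False`, a malformed observation such as `{"pos": [0.0], "id": 3}` against a `(2,)`-shaped `pos` spec passes silently, which is exactly the trade-off described above.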
Methods
action
action(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedTensor = (),
    seed: Optional[types.Seed] = None
) -> tf_agents.trajectories.PolicyStep
Generates next action given the time_step and policy_state.
Args | |
---|---|
time_step | A `TimeStep` tuple corresponding to `time_step_spec()`. |
policy_state | A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state. |
seed | Seed to use if action performs sampling (optional). |
Returns | |
---|---|
A `PolicyStep` named tuple containing: | |
`action` | An action Tensor matching the `action_spec`. |
`state` | A policy state tensor to be fed into the next call to action. |
`info` | Optional side information such as action log probabilities. |
Raises | |
---|---|
RuntimeError | If the subclass `__init__` didn't call `super().__init__`. |
ValueError or TypeError | If `validate_args` is True and inputs or outputs do not match `time_step_spec`, `policy_state_spec`, or `policy_step_spec`. |
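The `RuntimeError` case guards against subclasses that forget to run the base-class constructor. One way such a guard can work is sketched below in plain Python; the flag name `_setup_called`, the classes `PolicyBase`, `GoodPolicy`, and `BrokenPolicy`, and the error message are all hypothetical, and this is not presented as how TF-Agents implements the check.

```python
class PolicyBase:
  """Hypothetical base class that refuses to act if __init__ was skipped."""

  def __init__(self, action_spec):
    self._action_spec = action_spec
    self._setup_called = True  # Hypothetical marker set only by this __init__.

  def action(self, time_step, policy_state=()):
    # Public method: check setup, then delegate to the subclass hook.
    if not getattr(self, "_setup_called", False):
      raise RuntimeError(
          f"{type(self).__name__}.__init__ did not call super().__init__()")
    return self._action(time_step, policy_state)

  def _action(self, time_step, policy_state):
    raise NotImplementedError


class GoodPolicy(PolicyBase):
  def __init__(self):
    super().__init__(action_spec=(0, 1))

  def _action(self, time_step, policy_state):
    return 0


class BrokenPolicy(PolicyBase):
  def __init__(self):
    pass  # Bug: forgot to call super().__init__(...).

  def _action(self, time_step, policy_state):
    return 0
```

Calling `BrokenPolicy().action(None)` raises the `RuntimeError`, while `GoodPolicy().action(None)` returns `0`.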
distribution
distribution(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedTensor = ()
) -> tf_agents.trajectories.PolicyStep
Generates the distribution over next actions given the time_step.
Args | |
---|---|
time_step | A `TimeStep` tuple corresponding to `time_step_spec()`. |
policy_state | A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state. |
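Returns | |
---|---|
A `PolicyStep` named tuple containing: | |
`action` | A tf.distribution capturing the distribution of next actions. |
`state` | A policy state tensor for the next call to distribution. |
`info` | Optional side information such as action log probabilities. |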
Raises | |
---|---|
ValueError or TypeError | If `validate_args` is True and inputs or outputs do not match `time_step_spec`, `policy_state_spec`, or `policy_step_spec`. |
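Because `distribution` returns the distribution itself rather than a sampled action, a caller can sample from it, score an action, or take its mode. The `Categorical` class below is a stdlib-only stand-in for a TensorFlow Probability categorical distribution, and `distribution_from_logits` is a hypothetical helper; neither is TF-Agents code. Passing the same seed reproduces the same sample, which is roughly the role of the `seed` argument to `action`.

```python
import math
import random


class Categorical:
  """Hypothetical minimal categorical distribution over actions 0..n-1."""

  def __init__(self, probs):
    total = sum(probs)
    self.probs = [p / total for p in probs]  # Normalize to sum to 1.

  def sample(self, seed=None):
    # A fresh RNG per call, so the same seed always gives the same sample.
    rng = random.Random(seed)
    return rng.choices(range(len(self.probs)), weights=self.probs)[0]

  def log_prob(self, action):
    return math.log(self.probs[action])

  def mode(self):
    return max(range(len(self.probs)), key=self.probs.__getitem__)


def distribution_from_logits(logits):
  """Softmax over logits, as a policy network's output head might produce."""
  m = max(logits)  # Subtract the max for numerical stability.
  exps = [math.exp(l - m) for l in logits]
  return Categorical(exps)
```

For logits `[0.0, math.log(3.0)]` the probabilities are `[0.25, 0.75]`, so the mode is action `1`.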
get_initial_state
get_initial_state(
    batch_size: Optional[types.Int]
) -> tf_agents.typing.types.NestedTensor
Returns an initial state usable by the policy.
Args | |
---|---|
batch_size | Tensor or constant: size of the batch dimension. Can be None, in which case no batch dimension is added. |
Returns | |
---|---|
A nested object of type `policy_state` containing properly initialized Tensors. |
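The effect of `batch_size` can be sketched in plain Python for an LSTM-style state with a hidden vector `h` and a cell vector `c`. `StateSpec` and this `get_initial_state` function are hypothetical illustrations using nested lists of zeros instead of Tensors; an empty tuple spec stands for a stateless policy, mirroring the `policy_state_spec=()` default in the constructor.

```python
from collections import namedtuple

# Hypothetical leaf spec holding only a shape (dtype omitted for brevity).
StateSpec = namedtuple("StateSpec", ["shape"])


def zeros(shape):
  """Nested lists of 0.0 with the given shape; shape () gives a scalar."""
  if not shape:
    return 0.0
  return [zeros(shape[1:]) for _ in range(shape[0])]


def get_initial_state(state_spec, batch_size=None):
  """Zeros matching a nested spec; prepend a batch dim unless batch_size is None."""
  # Check StateSpec first: a namedtuple is also a tuple.
  if isinstance(state_spec, StateSpec):
    shape = tuple(state_spec.shape)
    if batch_size is not None:
      shape = (batch_size,) + shape
    return zeros(shape)
  if isinstance(state_spec, dict):
    return {k: get_initial_state(v, batch_size) for k, v in state_spec.items()}
  if isinstance(state_spec, (list, tuple)):
    return type(state_spec)(get_initial_state(v, batch_size) for v in state_spec)
  raise TypeError(f"unsupported spec: {state_spec!r}")
```

With `batch_size=None` each leaf keeps its spec shape; with `batch_size=2` each leaf gains a leading dimension of size 2.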
update
update(
    policy,
    tau: float = 1.0,
    tau_non_trainable: Optional[float] = None,
    sort_variables_by_name: bool = False
) -> tf.Operation
Update the current policy with another policy.
This includes copying the variables from the other policy.
Args | |
---|---|
policy | Another policy to update from. |
tau | A float scalar in [0, 1]. When tau is 1.0 (the default), a hard update is performed. This value is used for trainable variables. |
tau_non_trainable | A float scalar in [0, 1] for non-trainable variables. If None, the value of tau is used. |
sort_variables_by_name | A bool; when True, the variables are sorted by name before doing the update. |
Returns | |
---|---|
A TF op to do the update. |
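For tau below 1.0 this is a soft (Polyak-style) update. The sketch below assumes the conventional rule `v_target = (1 - tau) * v_target + tau * v_source`, which reduces to a hard copy at tau = 1.0 as documented above. It models variables as a plain dict of floats, pairs target and source variables by position (after optional sorting by name), and marks trainable variables with a `trainable_names` set; this `update` function and its representation are illustrative assumptions, not the TF-Agents implementation.

```python
def update(target_vars, source_vars, trainable_names, tau=1.0,
           tau_non_trainable=None, sort_variables_by_name=False):
  """Hypothetical soft update: returns new values for target_vars."""
  if tau_non_trainable is None:
    tau_non_trainable = tau  # Mirrors the documented default: reuse tau.
  if len(target_vars) != len(source_vars):
    raise ValueError("policies must have the same number of variables")
  t_items = list(target_vars.items())
  s_items = list(source_vars.items())
  if sort_variables_by_name:
    # Sorting both lists makes positional pairing match variables by name.
    t_items.sort(key=lambda kv: kv[0])
    s_items.sort(key=lambda kv: kv[0])
  new_values = {}
  for (t_name, t_val), (_, s_val) in zip(t_items, s_items):
    rate = tau if t_name in trainable_names else tau_non_trainable
    new_values[t_name] = (1.0 - rate) * t_val + rate * s_val
  return new_values
```

With `tau=0.1` a trainable target weight of `0.0` moves to `0.1` toward a source value of `1.0`, while `tau_non_trainable=1.0` still copies non-trainable variables (such as counters) exactly.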