Exposes a numpy API for saved_model policies in Eager mode.
Inherits From: PyTFEagerPolicyBase, PyPolicy
tf_agents.policies.SavedModelPyTFEagerPolicy(
    model_path: Text,
    time_step_spec: Optional[tf_agents.trajectories.TimeStep] = None,
    action_spec: Optional[tf_agents.typing.types.DistributionSpecV2] = None,
    policy_state_spec: tf_agents.typing.types.NestedTensorSpec = (),
    info_spec: tf_agents.typing.types.NestedTensorSpec = (),
    load_specs_from_pbtxt: bool = False,
    use_tf_function: bool = False,
    batch_time_steps=True
)
Methods
action
action(
    time_step: tf_agents.trajectories.TimeStep,
    policy_state: tf_agents.typing.types.NestedArray = (),
    seed: Optional[types.Seed] = None
) -> tf_agents.trajectories.PolicyStep
Generates next action given the time_step and policy_state.
Args | |
---|---|
time_step | A TimeStep tuple corresponding to time_step_spec(). |
policy_state | An optional previous policy_state. |
seed | Seed to use if action uses sampling (optional). |
Returns | |
---|---|
A PolicyStep named tuple containing: action, a nest of action arrays matching action_spec(); state, a nest of policy states to be fed into the next call to action(); info, optional side information such as action log probabilities. |
get_initial_state
get_initial_state(
    batch_size: Optional[int] = None
) -> tf_agents.typing.types.NestedArray
Returns an initial state usable by the policy.
Args | |
---|---|
batch_size | An optional batch size. |

Returns | |
---|---|
An initial policy state. |
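Because action() returns the next policy state in PolicyStep.state, a caller is expected to start from get_initial_state() and feed each returned state into the following action() call. Below is a minimal sketch of that loop, assuming a duck-typed environment with reset() and step(action) whose time steps expose is_last(), as TF-Agents PyEnvironment and TimeStep do. The names run_episode, ToyEnv, ToyTimeStep, CountingPolicy and the local PolicyStep namedtuple are hypothetical stand-ins so the example runs without TensorFlow; in practice `policy` would be a SavedModelPyTFEagerPolicy.

```python
from collections import namedtuple
from typing import Any, List, Optional, Tuple

# Hypothetical stand-in with the same three fields as tf_agents PolicyStep.
PolicyStep = namedtuple("PolicyStep", ["action", "state", "info"])


class ToyTimeStep:
    """Hypothetical time step exposing is_last(), like a TF-Agents TimeStep."""

    def __init__(self, observation: int, last: bool):
        self.observation = observation
        self._last = last

    def is_last(self) -> bool:
        return self._last


class ToyEnv:
    """Hypothetical environment whose episodes last `length` steps."""

    def __init__(self, length: int):
        self.length = length
        self.t = 0

    def reset(self) -> ToyTimeStep:
        self.t = 0
        return ToyTimeStep(self.t, last=False)

    def step(self, action: Any) -> ToyTimeStep:
        self.t += 1
        return ToyTimeStep(self.t, last=self.t >= self.length)


class CountingPolicy:
    """Hypothetical stateful policy: its state counts the actions taken so far."""

    def get_initial_state(self, batch_size: Optional[int] = None) -> int:
        return 0

    def action(self, time_step, policy_state=(), seed=None) -> PolicyStep:
        count = policy_state + 1
        return PolicyStep(action=count, state=count, info=())


def run_episode(env, policy, batch_size: Optional[int] = None) -> Tuple[List[Any], Any]:
    """Runs one episode, threading the policy state between action() calls."""
    time_step = env.reset()
    policy_state = policy.get_initial_state(batch_size)
    actions = []
    while not time_step.is_last():
        policy_step = policy.action(time_step, policy_state)
        actions.append(policy_step.action)
        # The returned state is fed into the next call to action().
        policy_state = policy_step.state
        time_step = env.step(policy_step.action)
    return actions, policy_state


actions, final_state = run_episode(ToyEnv(length=3), CountingPolicy())
print(actions, final_state)  # [1, 2, 3] 3
```

The batch_size forwarded to get_initial_state() is left to the caller here, since the right value depends on whether the environment and policy are batched; the sketch does not assume one.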
get_metadata
get_metadata()
Returns the metadata of the saved model.
get_train_step
get_train_step() -> tf_agents.typing.types.Int
Returns the training global step of the saved model.
get_train_step_from_last_restored_checkpoint_path
get_train_step_from_last_restored_checkpoint_path() -> Optional[int]
Returns the training step of the restored checkpoint.
update_from_checkpoint
update_from_checkpoint(
    checkpoint_path: Text
)
Allows users to update saved_model variables directly from a checkpoint.
checkpoint_path is a path that was passed to either PolicySaver.save() or PolicySaver.save_checkpoint(). The policy looks for a set of checkpoint files with the file prefix `
Args | |
---|---|
checkpoint_path | Path to the checkpoint to restore and use to update this policy. |
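When a learner periodically writes checkpoints with PolicySaver.save_checkpoint(), another process holding a SavedModelPyTFEagerPolicy can call update_from_checkpoint() to pick up the new variables. Below is a minimal sketch of a refresh helper, assuming the caller already knows the training step that goes with each checkpoint path (for example from its own naming scheme). maybe_update_policy and FakePolicy are hypothetical; only update_from_checkpoint() and get_train_step_from_last_restored_checkpoint_path() come from the API above.

```python
from typing import List, Optional, Sequence, Tuple


def maybe_update_policy(policy, checkpoints: Sequence[Tuple[int, str]]) -> bool:
    """Restores the newest checkpoint if it is newer than the last one restored.

    `checkpoints` holds (train_step, checkpoint_path) pairs; how they are
    discovered (e.g. by listing a directory) is left to the caller.
    Returns True if update_from_checkpoint() was called.
    """
    if not checkpoints:
        return False
    newest_step, newest_path = max(checkpoints)
    last_step = policy.get_train_step_from_last_restored_checkpoint_path()
    # None means no checkpoint has been restored yet.
    if last_step is not None and newest_step <= last_step:
        return False
    policy.update_from_checkpoint(newest_path)
    return True


class FakePolicy:
    """Hypothetical stand-in that records restores instead of loading variables."""

    def __init__(self):
        self.restored: List[str] = []
        self._last_step: Optional[int] = None

    def get_train_step_from_last_restored_checkpoint_path(self) -> Optional[int]:
        return self._last_step

    def update_from_checkpoint(self, checkpoint_path: str) -> None:
        self.restored.append(checkpoint_path)
        # Toy assumption: this fake reads the step after the last underscore
        # of the path; the real policy's behavior is not modeled here.
        self._last_step = int(checkpoint_path.rsplit("_", 1)[-1])


policy = FakePolicy()
ckpts = [(100, "/tmp/ckpt_100"), (200, "/tmp/ckpt_200")]
print(maybe_update_policy(policy, ckpts))  # True
print(maybe_update_policy(policy, ckpts))  # False
print(policy.restored)  # ['/tmp/ckpt_200']
```

Comparing against get_train_step_from_last_restored_checkpoint_path() means repeated calls with the same checkpoint list restore only when a newer step appears.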
variables
variables()