Triggers saves of policy checkpoints and saved models of an agent's policy.
Inherits From: IntervalTrigger
tf_agents.train.triggers.PolicySavedModelTrigger(
    saved_model_dir: Text,
    agent: tf_agents.agents.TFAgent,
    train_step: tf.Variable,
    interval: int,
    async_saving: bool = False,
    metadata_metrics: Optional[Mapping[Text, py_metric.PyMetric]] = None,
    start: int = 0,
    extra_concrete_functions: Optional[Sequence[Tuple[str, policy_saver.def_function.Function]]] = None,
    batch_size: Optional[int] = None,
    use_nest_path_signatures: bool = True,
    save_greedy_policy=True,
    save_collect_policy=True,
    input_fn_and_spec: Optional[tf_agents.policies.policy_saver.InputFnAndSpecType] = None
)
On construction this trigger will generate a saved_model for a `greedy_policy`, a `collect_policy`, and a `raw_policy`. When triggered, a checkpoint will be saved which can be used to update any of the saved_model policies.
Args | |
---|---|
`saved_model_dir` | Base dir where checkpoints will be saved.
`agent` | Agent to extract policies from.
`train_step` | `tf.Variable` which keeps track of the number of train steps.
`interval` | How often, in train_steps, the trigger will save. Note that as long as >= `interval` steps have passed since the last trigger, the event gets triggered. The current value is not necessarily `interval` steps away from the last triggered value.
`async_saving` | If True, saving will be done asynchronously in a separate thread. Note that if this is on, the variable values in the saved checkpoints/models are not deterministic.
`metadata_metrics` | A dictionary of metrics, whose `result()` method returns a scalar to be saved along with the policy. Currently only supported when `async_saving` is False.
`start` | Initial value for the trigger, passed directly to the base class. It controls from which train step the weights of the model are saved.
`extra_concrete_functions` | Optional sequence of extra concrete functions to register in the policy savers. The sequence should consist of tuples with a string name for the function and the `tf.function` to register. Note this does not support adding extra assets.
`batch_size` | The number of batch entries the policy will process at a time. This must be either None (unknown batch size) or a python integer.
`use_nest_path_signatures` | If True, SavedModel spec signatures will be created based on the structure of the specs. Otherwise all specs must have unique names.
`save_greedy_policy` | Disable when an agent's policy distribution method does not support `mode`.
`save_collect_policy` | Disable when not saving the collect policy.
`input_fn_and_spec` | A `(input_fn, tensor_spec)` tuple where `input_fn` is a function that takes inputs according to `tensor_spec` and converts them to the `(time_step, policy_state)` tuple that is used as the input to the `action_fn`. When `input_fn_and_spec` is set, `tensor_spec` is the input for the action signature. When `input_fn_and_spec` is None, the action signature takes `(time_step, policy_state)` as input.
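The `interval` semantics described above can be illustrated with a minimal, pure-Python stand-in. This is a sketch of the documented behavior only, not the TF-Agents `IntervalTrigger` implementation; the class name and the boolean return value are hypothetical (the real trigger returns None):

```python
class IntervalTriggerSketch:
    """Hypothetical stand-in for the documented interval behavior."""

    def __init__(self, interval, start=0):
        self.interval = interval
        self.last_value = start  # value at which the trigger last fired

    def __call__(self, value):
        # Fire whenever at least `interval` steps have passed since the last
        # trigger; the fire point need not be a multiple of `interval`.
        if value - self.last_value >= self.interval:
            self.last_value = value
            return True
        return False


trigger = IntervalTriggerSketch(interval=10)
# Drive the trigger with uneven step increments, as a training loop might.
fired_at = [step for step in range(0, 50, 3) if trigger(step)]
print(fired_at)  # → [12, 24, 36, 48]
```

Note the trigger points are at least 10 steps apart but are not multiples of 10, matching the caveat in the `interval` description.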
Methods
reset
reset() -> None
Resets the trigger interval.
set_start
set_start(
    start: int
) -> None
__call__
__call__(
    value: int, force_trigger: bool = False
) -> None
Maybe trigger the event based on the interval.
Args | |
---|---|
`value` | The value for triggering.
`force_trigger` | If True, the trigger fires regardless of the interval, unless the last trigger value is equal to `value`.
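The `force_trigger` behavior can be sketched the same way. This is a hypothetical pure-Python stand-in for the documented semantics, not the actual TF-Agents code, and it returns a bool for illustration where the real `__call__` returns None:

```python
class ForceTriggerSketch:
    """Hypothetical stand-in for the documented __call__ behavior."""

    def __init__(self, interval, start=0):
        self.interval = interval
        self.last_value = start  # value at which the trigger last fired

    def __call__(self, value, force_trigger=False):
        due = value - self.last_value >= self.interval
        # A forced call fires regardless of the interval, unless the
        # trigger already fired at exactly this value.
        if force_trigger and value != self.last_value:
            due = True
        if due:
            self.last_value = value
            return True
        return False


t = ForceTriggerSketch(interval=100)
print(t(10))                      # → False: only 10 of 100 steps have passed
print(t(10, force_trigger=True))  # → True: forced despite the interval
print(t(10, force_trigger=True))  # → False: last trigger value equals 10
```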