Common utilities for TF-Agents.
Classes
class AggregatedLosses
: AggregatedLosses(total_loss, weighted, regularization)
class Checkpointer
: Checkpoints training state, policy state, and replay_buffer state.
class EagerPeriodically
: EagerPeriodically performs the ops defined in body.
class OUProcess
: A zero-mean Ornstein-Uhlenbeck process.
class Periodically
: Periodically performs the ops defined in body.
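OUProcess, like the ornstein_uhlenbeck_process function listed under Functions, produces temporally correlated, zero-mean noise. The following is a minimal pure-Python sketch of a discrete-time update of the form x <- (1 - damping) * x + N(0, stddev^2). The class name SimpleOUProcess and the defaults damping=0.15 and stddev=0.2 are assumptions for illustration, not the TF-Agents signature.

```python
import random


class SimpleOUProcess:
    """Hypothetical zero-mean Ornstein-Uhlenbeck noise generator (sketch only)."""

    def __init__(self, initial_value=0.0, damping=0.15, stddev=0.2, seed=None):
        self._x = initial_value          # current state of the process
        self._damping = damping          # assumed pull toward zero per step
        self._stddev = stddev            # assumed scale of the Gaussian innovation
        self._rng = random.Random(seed)  # local RNG so runs are reproducible

    def __call__(self):
        # Decay the previous value toward zero, then add fresh Gaussian noise.
        noise = self._rng.gauss(0.0, self._stddev)
        self._x = (1.0 - self._damping) * self._x + noise
        return self._x


ou = SimpleOUProcess(seed=0)
samples = [ou() for _ in range(5)]
```

Because each sample depends on the previous one, consecutive values are correlated rather than independent, which is why OU noise is commonly used for exploration with continuous actions (for example in DDPG).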
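Periodically and EagerPeriodically run body only on some of the calls made to them. A minimal pure-Python sketch of the idea, assuming body fires when the number of calls so far is a positive multiple of period and that a period of None disables it; the class name CountingPeriodically is hypothetical and the sketch does not reproduce the TF-Agents graph/eager machinery.

```python
class CountingPeriodically:
    """Hypothetical sketch: invoke `body` once every `period` calls."""

    def __init__(self, body, period):
        if period is not None and period <= 0:
            raise ValueError("period must be a positive integer or None")
        self._body = body
        self._period = period
        self._calls = 0  # number of times this object has been called

    def __call__(self):
        if self._period is None:
            return None  # assumption: a period of None means "never run body"
        self._calls += 1
        # Run body only when the call count is a positive multiple of period.
        if self._calls % self._period == 0:
            return self._body()
        return None


log = []
every_third = CountingPeriodically(lambda: log.append("ran"), period=3)
for _ in range(7):
    every_third()
```

After seven calls with period=3, body has run twice (on calls 3 and 6).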
Functions
aggregate_losses(...)
: Aggregates and scales per example loss and regularization losses.
assert_members_are_not_overridden(...)
: Asserts public members of base_cls are not overridden in instance.
check_matching_networks(...)
: Check that two networks have matching input specs and variables.
check_no_shared_variables(...)
: Checks that there are no shared trainable variables in the two networks.
check_tf1_allowed(...)
: Raises an error if running in TF1 (non-eager) mode and this is disabled.
clip_to_spec(...)
: Clips value to a given bounded tensor spec.
compute_returns(...)
: Compute the return from each index in an episode.
convert_q_logits_to_values(...)
: Converts a set of Q-value logits into Q-values using the provided support.
create_variable(...)
: Create a variable.
deduped_network_variables(...)
: Returns a list of variables in net1 that are not in any other nets.
discounted_future_sum(...)
: Discounted future sum of batch-major values.
discounted_future_sum_masked(...)
: Discounted future sum of batch-major values, with values past each episode's length masked to zero.
element_wise_squared_loss(...)
: Computes the squared difference between two tensors element-wise, without reduction.
entropy(...)
: Computes total entropy of distribution.
extract_shared_variables(...)
: Separates shared variables from the given collections.
function(...)
: Wrapper for tf.function with TF Agents-specific customizations.
function_in_tf1(...)
: Wrapper that returns common.function if using TF1.
generate_tensor_summaries(...)
: Generates various summaries of tensor, such as histogram, max, min, etc.
get_contiguous_sub_episodes(...)
: Computes mask on sub-episodes which includes only contiguous components.
get_episode_mask(...)
: Create a mask that is 0.0 for all final steps, 1.0 elsewhere.
has_eager_been_enabled(...)
: Returns true iff in TF2 or in TF1 with eager execution enabled.
index_with_actions(...)
: Index into q_values using actions.
initialize_uninitialized_variables(...)
: Initialize any pending variables that are uninitialized.
join_scope(...)
: Joins a parent and child scope using /, checking for empty/None.
load_spec(...)
: Loads a data spec from a file.
log_probability(...)
: Computes log probability of actions given distribution.
maybe_copy_target_network_with_checks(...)
: Copies the network into target if None and checks for shared variables.
ornstein_uhlenbeck_process(...)
: An op for generating noise from a zero-mean Ornstein-Uhlenbeck process.
periodically(...)
: Periodically performs the TensorFlow op in body.
replicate(...)
: Replicates a tensor so as to match the given outer shape.
resource_variables_enabled(...)
: Returns whether resource variables are enabled.
safe_has_state(...)
: Safely checks that state is not in (None, (), []).
save_spec(...)
: Saves the given spec nest as a StructProto.
scale_to_spec(...)
: Shapes and scales a batch into the given spec bounds.
set_default_tf_function_parameters(...)
: Generates a decorator that sets default parameters for tf.function.
shift_values(...)
: Shifts batch-major values in time by some amount.
soft_device_placement(...)
: Context manager for soft device placement, allowing summaries on CPU.
soft_variables_update(...)
: Performs a soft/hard update of variables from the source to the target.
spec_means_and_magnitudes(...)
: Get the center and magnitude of the ranges in action spec.
summarize_tensor_dict(...)
: Generates summaries of all tensors in tensor_dict.
transpose_batch_time(...)
: Transposes the batch and time dimensions of a Tensor.
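clip_to_spec clips a value into the bounds of a bounded tensor spec. A minimal sketch on plain Python lists, assuming per-element minimum and maximum bounds; the function name clip_to_spec_sketch is hypothetical and stands in for the tensor-based original.

```python
def clip_to_spec_sketch(values, minimum, maximum):
    """Clip each value into its [minimum, maximum] range (element-wise)."""
    if not (len(values) == len(minimum) == len(maximum)):
        raise ValueError("values and bounds must have the same length")
    return [min(max(v, lo), hi) for v, lo, hi in zip(values, minimum, maximum)]
```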
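compute_returns computes the return from each index in an episode. Assuming per-step rewards r_t and discounts d_t (with d_t = 0 at the final step so returns do not leak across episodes), the return satisfies G_t = r_t + d_t * G_{t+1}, which one backward pass evaluates. A pure-Python sketch for a single episode; the name compute_returns_sketch is hypothetical.

```python
def compute_returns_sketch(rewards, discounts):
    """Return G_t = r_t + d_t * G_{t+1} for each index t (G past the end is 0)."""
    if len(rewards) != len(discounts):
        raise ValueError("rewards and discounts must have the same length")
    returns = [0.0] * len(rewards)
    future = 0.0
    # Walk backwards so each step reuses the return of the step after it.
    for t in reversed(range(len(rewards))):
        future = rewards[t] + discounts[t] * future
        returns[t] = future
    return returns
```

For example, rewards [1, 1, 1] with discounts [0.5, 0.5, 0.0] give returns [1.75, 1.5, 1.0].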
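convert_q_logits_to_values turns Q-value logits into Q-values using a support. In categorical (distributional) Q-learning, an action's value is a distribution over fixed support atoms, and the Q-value is its expectation. The sketch below assumes a softmax over the logits of a single action; the function name is hypothetical.

```python
import math


def q_logits_to_values_sketch(logits, support):
    """Expected value of a categorical distribution given logits over atoms."""
    # Numerically stable softmax: subtract the max logit before exponentiating.
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Q-value = sum over atoms of probability * atom location.
    return sum(p * z for p, z in zip(probs, support))
```

With logits [0, log 3] over the support [0, 4], the probabilities are 0.25 and 0.75, so the Q-value is about 3.0.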
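discounted_future_sum takes batch-major values of shape [batch_size, total_steps] and, for each entry (i, j), sums gamma^k * values[i][j + k] for k = 0 .. num_steps - 1, treating positions past the end as zero. A pure-Python sketch on nested lists; the function name is hypothetical.

```python
def discounted_future_sum_sketch(values, gamma, num_steps):
    """values: batch-major nested list [batch][time]; returns the same shape."""
    result = []
    for row in values:
        total_steps = len(row)
        out_row = []
        for j in range(total_steps):
            acc = 0.0
            # Sum up to num_steps future entries; steps past the end count as zero.
            for k in range(num_steps):
                if j + k >= total_steps:
                    break
                acc += (gamma ** k) * row[j + k]
            out_row.append(acc)
        result.append(out_row)
    return result
```

For values [[1, 2, 3, 4]], gamma 0.5 and num_steps 2, the result is [[2.0, 3.5, 5.0, 4.0]].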
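entropy computes the total entropy of a distribution, which for an action made of several independent parts is the sum of the parts' entropies. The sketch below assumes every part is categorical and computes H(p) = -sum(p * log p) in nats for each; both function names are hypothetical.

```python
import math


def categorical_entropy(probs):
    """Shannon entropy -sum(p * log p) in nats; zero-probability terms contribute 0."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def total_entropy_sketch(components):
    """Sum the entropies of independent categorical components (e.g. action dims)."""
    return sum(categorical_entropy(p) for p in components)
```

A fair coin contributes log 2 and a deterministic component contributes 0, so total_entropy_sketch([[0.5, 0.5], [1.0]]) is log 2.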
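get_episode_mask creates a mask that is 0.0 for all final steps and 1.0 elsewhere. A sketch over a flat list of step types, assuming the encoding FIRST=0, MID=1, LAST=2 (intended to mirror StepType in tf_agents.trajectories); the constant and function names are hypothetical.

```python
LAST = 2  # assumption: mirrors StepType.LAST in tf_agents.trajectories


def episode_mask_sketch(step_types):
    """1.0 for every step except final (LAST) steps, which get 0.0."""
    return [0.0 if s == LAST else 1.0 for s in step_types]
```

Multiplying per-step quantities by this mask zeroes out the contribution of each episode's final step.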
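index_with_actions indexes into q_values using actions, that is, it picks the Q-value of the chosen action for each element. The sketch covers only a [batch, num_actions] table with one integer action per batch element; the function name is hypothetical.

```python
def index_with_actions_sketch(q_values, actions):
    """Pick q_values[b][actions[b]] for each batch element b."""
    if len(q_values) != len(actions):
        raise ValueError("q_values and actions must have the same batch size")
    return [row[a] for row, a in zip(q_values, actions)]
```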
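join_scope joins a parent and child scope with "/", falling back to whichever one is non-empty. A sketch of that rule; the function name is hypothetical.

```python
def join_scope_sketch(parent_scope, child_scope):
    """Return 'parent/child', or whichever scope is non-empty if the other is empty/None."""
    if not parent_scope:
        return child_scope
    if not child_scope:
        return parent_scope
    return "/".join([parent_scope, child_scope])
```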
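replicate copies a tensor so that it gains the given outer shape; the result's shape is outer_shape followed by the original shape. A sketch on nested lists (the function name is hypothetical), using deep copies so the replicas do not share storage:

```python
import copy


def replicate_sketch(value, outer_shape):
    """Nest copies of `value` so the result has shape outer_shape + shape(value)."""
    result = value
    # Build from the innermost outer dimension outward.
    for dim in reversed(outer_shape):
        result = [copy.deepcopy(result) for _ in range(dim)]
    return result
```

For example, replicating [1, 2] with outer shape [2, 3] yields a 2 x 3 grid whose every cell is [1, 2].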
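safe_has_state checks that state is not in (None, (), []). Presumably "safely" means avoiding a plain membership test, since "state in (...)" calls == on the state, which can misbehave for tensor or array values. A sketch that only uses identity, type, and length checks, under that assumption; the function name is hypothetical.

```python
def safe_has_state_sketch(state):
    """True unless state is None or an empty tuple/list."""
    if state is None:
        return False
    # Check emptiness via type and length instead of ==, so array-like
    # states never trigger element-wise comparisons.
    if isinstance(state, (tuple, list)) and len(state) == 0:
        return False
    return True
```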
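spec_means_and_magnitudes returns the center and half-width of each bounded range, and scale_to_spec uses them to map a batch of values in [-1, 1] onto the spec bounds as mean + magnitude * x. A sketch on plain lists, assuming per-dimension minimum and maximum; both function names are hypothetical.

```python
def spec_means_and_magnitudes_sketch(minimum, maximum):
    """Center and half-width of each [minimum, maximum] range."""
    means = [(hi + lo) / 2.0 for lo, hi in zip(minimum, maximum)]
    magnitudes = [(hi - lo) / 2.0 for lo, hi in zip(minimum, maximum)]
    return means, magnitudes


def scale_to_spec_sketch(batch, minimum, maximum):
    """Map each row of values in [-1, 1] onto the [minimum, maximum] bounds."""
    means, magnitudes = spec_means_and_magnitudes_sketch(minimum, maximum)
    return [[m + g * x for x, m, g in zip(row, means, magnitudes)] for row in batch]
```

This pairs naturally with a tanh-squashed policy output: -1 maps to the minimum, 0 to the center, and 1 to the maximum.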
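soft_variables_update moves each target variable toward its source: v_target <- (1 - tau) * v_target + tau * v_source, so tau = 1.0 is a hard copy and a small tau gives a slowly tracking target network. A sketch on plain floats; the function name is hypothetical.

```python
def soft_update_sketch(source, target, tau=1.0):
    """Return new target values: (1 - tau) * target + tau * source, element-wise."""
    if len(source) != len(target):
        raise ValueError("source and target must have the same length")
    return [(1.0 - tau) * t + tau * s for s, t in zip(source, target)]
```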
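transpose_batch_time swaps the batch and time dimensions. A sketch on nested lists of rank 2 or higher (the function name is hypothetical): zip(*x) groups the t-th element of every batch row together.

```python
def transpose_batch_time_sketch(x):
    """Swap the first two axes of a nested list: [B][T][...] -> [T][B][...]."""
    return [list(step) for step in zip(*x)]
```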