Module: tf_agents.utils.common

Common utilities for TF-Agents.

Classes

class AggregatedLosses: Named tuple with the fields total_loss, weighted, and regularization.

class Checkpointer: Checkpoints training state, policy state, and replay_buffer state.

class EagerPeriodically: EagerPeriodically performs the ops defined in body.

class OUProcess: A zero-mean Ornstein-Uhlenbeck process.

class Periodically: Periodically performs the ops defined in body.

Functions

aggregate_losses(...): Aggregates and scales per example loss and regularization losses.

assert_members_are_not_overridden(...): Asserts public members of base_cls are not overridden in instance.

check_matching_networks(...): Check that two networks have matching input specs and variables.

check_no_shared_variables(...): Checks that there are no shared trainable variables in the two networks.

check_tf1_allowed(...): Raises an error if running in TF1 (non-eager) mode and this is disabled.

clip_to_spec(...): Clips value to a given bounded tensor spec.

compute_returns(...): Compute the return from each index in an episode.
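As a rough illustration of the recursion this entry describes, the pure-Python sketch below accumulates rewards backwards from the end of an episode. It assumes per-step discounts that are 0 on the final step of an episode; the helper name compute_returns_sketch is hypothetical and this is not the library's TensorFlow implementation.

```python
def compute_returns_sketch(rewards, discounts):
    """Hypothetical helper: returns[t] = rewards[t] + discounts[t] * returns[t + 1]."""
    returns = [0.0] * len(rewards)
    next_return = 0.0  # nothing follows the last step
    for t in reversed(range(len(rewards))):
        next_return = rewards[t] + discounts[t] * next_return
        returns[t] = next_return
    return returns


# Three-step episode; the discount on the last step is 0 (assumed terminal).
returns = compute_returns_sketch([1.0, 1.0, 1.0], [0.5, 0.5, 0.0])
print(returns)  # [1.75, 1.5, 1.0]
```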

convert_q_logits_to_values(...): Converts a set of Q-value logits into Q-values using the provided support.

create_variable(...): Create a variable.

deduped_network_variables(...): Returns a list of variables in net1 that are not in any other nets.

discounted_future_sum(...): Discounted future sum of batch-major values.

discounted_future_sum_masked(...): Discounted future sum of batch-major values, with values past the end of each episode masked to 0.
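The windowed sum behind discounted_future_sum can be sketched in plain Python for a single sequence. The sketch assumes each entry sums up to num_steps future values weighted by increasing powers of gamma, with a shorter window near the end of the sequence; the helper name is hypothetical, the batch dimension is omitted, and the masked variant is not shown.

```python
def discounted_future_sum_sketch(values, gamma, num_steps):
    """Hypothetical helper: result[t] = sum_i gamma**i * values[t + i], i < num_steps."""
    result = []
    for t in range(len(values)):
        window = values[t:t + num_steps]  # truncated at the end of the sequence
        result.append(sum(gamma ** i * v for i, v in enumerate(window)))
    return result


sums = discounted_future_sum_sketch([1.0, 2.0, 4.0], gamma=0.5, num_steps=2)
print(sums)  # [2.0, 4.0, 4.0]
```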

element_wise_huber_loss(...)

element_wise_squared_loss(...)
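Neither loss entry above carries a summary, so the following sketch shows the textbook per-element definitions for scalar inputs. It assumes the standard Huber loss with a threshold of 1.0 and an unscaled squared error; both helper names are hypothetical, and the library functions operate on tensors rather than floats.

```python
def huber_loss_sketch(x, y, delta=1.0):
    """Hypothetical helper: quadratic for small errors, linear for large ones."""
    error = abs(x - y)
    if error <= delta:
        return 0.5 * error ** 2
    return delta * (error - 0.5 * delta)


def squared_loss_sketch(x, y):
    """Hypothetical helper: plain squared error, no 1/2 factor."""
    return (x - y) ** 2


print(huber_loss_sketch(0.5, 0.0))    # 0.125 (quadratic region)
print(huber_loss_sketch(3.0, 0.0))    # 2.5 (linear region)
print(squared_loss_sketch(3.0, 1.0))  # 4.0
```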

entropy(...): Computes total entropy of distribution.

extract_shared_variables(...): Separates shared variables from the given collections.

function(...): Wrapper for tf.function with TF-Agents-specific customizations.

function_in_tf1(...): Wrapper that returns common.function if using TF1.

generate_tensor_summaries(...): Generates various summaries of tensor such as histogram, max, min, etc.

get_contiguous_sub_episodes(...): Computes mask on sub-episodes which includes only contiguous components.

get_episode_mask(...): Create a mask that is 0.0 for all final steps, 1.0 elsewhere.

has_eager_been_enabled(...): Returns true iff in TF2 or in TF1 with eager execution enabled.

in_legacy_tf1(...)

index_with_actions(...): Index into q_values using actions.
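The indexing this entry describes can be pictured with nested lists: for each batch element, pick the Q-value of the action that was taken. The sketch assumes a [batch, num_actions] layout with one integer action per row; the helper name is hypothetical, and the library function works on tensors and may support additional dimensions.

```python
def index_with_actions_sketch(q_values, actions):
    """Hypothetical helper: select q_values[b][actions[b]] for every batch row b."""
    return [row[action] for row, action in zip(q_values, actions)]


q_values = [[0.1, 0.9, 0.3],
            [0.7, 0.2, 0.5]]
selected = index_with_actions_sketch(q_values, actions=[1, 2])
print(selected)  # [0.9, 0.5]
```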

initialize_uninitialized_variables(...): Initialize any pending variables that are uninitialized.

join_scope(...): Joins a parent and child scope using /, checking for empty/none.

load_spec(...): Loads a data spec from a file.

log_probability(...): Computes log probability of actions given distribution.

maybe_copy_target_network_with_checks(...): Copies the network into target if None and checks for shared variables.

ornstein_uhlenbeck_process(...): An op for generating noise from a zero-mean Ornstein-Uhlenbeck process.
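A zero-mean Ornstein-Uhlenbeck process is commonly discretized as a damped random walk. The sketch below assumes the update x_next = (1 - damping) * x + N(0, stddev); the helper name and the default values for damping and stddev are illustrative assumptions, not confirmed library defaults.

```python
import random


def ou_step_sketch(x, damping=0.15, stddev=0.2, rng=random):
    """Hypothetical helper: one step of a zero-mean OU process."""
    return (1.0 - damping) * x + rng.gauss(0.0, stddev)


# Generate a short, reproducible noise trajectory starting from 0.
rng = random.Random(0)
x = 0.0
trajectory = []
for _ in range(5):
    x = ou_step_sketch(x, rng=rng)
    trajectory.append(x)
print(trajectory)
```

With stddev set to 0 the process simply decays towards zero, which makes the damping term easy to check in isolation.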

periodically(...): Periodically performs the tensorflow op in body.

replicate(...): Replicates a tensor so as to match the given outer shape.

resource_variables_enabled(...)

safe_has_state(...): Safely checks state not in (None, (), []).

save_spec(...): Saves the given spec nest as a StructProto.

scale_to_spec(...): Shapes and scales a batch into the given spec bounds.
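One way to read the scaling step is as an affine map from the interval [-1, 1] onto the spec bounds, using the center and half-width of the range (the quantities spec_means_and_magnitudes returns). The sketch assumes scalar bounds and inputs already in [-1, 1]; the helper names are hypothetical and the reshaping step is omitted.

```python
def means_and_magnitudes_sketch(minimum, maximum):
    """Hypothetical helper: center and half-width of the range [minimum, maximum]."""
    return (maximum + minimum) / 2.0, (maximum - minimum) / 2.0


def scale_to_bounds_sketch(value, minimum, maximum):
    """Hypothetical helper: map a value in [-1, 1] onto [minimum, maximum]."""
    mean, magnitude = means_and_magnitudes_sketch(minimum, maximum)
    return mean + magnitude * value


scaled = [scale_to_bounds_sketch(v, 0.0, 10.0) for v in (-1.0, 0.0, 1.0)]
print(scaled)  # [0.0, 5.0, 10.0]
```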

set_default_tf_function_parameters(...): Generates a decorator that sets default parameters for tf.function.

shift_values(...): Shifts batch-major values in time by some amount.

soft_device_placement(...): Context manager for soft device placement, allowing summaries on CPU.

soft_variables_update(...): Performs a soft/hard update of variables from the source to the target.
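The soft/hard distinction comes down to a single interpolation factor. The sketch below assumes the usual rule target = (1 - tau) * target + tau * source, where tau = 1.0 gives a hard copy; the helper name is hypothetical, and plain lists of floats stand in for TensorFlow variables.

```python
def soft_update_sketch(source, target, tau=1.0):
    """Hypothetical helper: blend each target value towards its source value."""
    return [(1.0 - tau) * t + tau * s for s, t in zip(source, target)]


source = [1.0, 2.0]
target = [0.0, 0.0]
print(soft_update_sketch(source, target, tau=0.5))  # [0.5, 1.0] (soft update)
print(soft_update_sketch(source, target, tau=1.0))  # [1.0, 2.0] (hard update)
```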

spec_means_and_magnitudes(...): Get the center and magnitude of the ranges in action spec.

summarize_scalar_dict(...)

summarize_tensor_dict(...): Generates summaries of all tensors in tensor_dict.

transpose_batch_time(...): Transposes the batch and time dimensions of a Tensor.

Other Members

MISSING_RESOURCE_VARIABLES_ERROR: Error message string with the following text:

    Resource variables are not enabled. Please enable them by adding the following
    code to your main() method:
     tf.compat.v1.enable_resource_variables()
    For unit tests, subclass `tf_agents.utils.test_utils.TestCase`.

absolute_import: Instance of __future__._Feature

division: Instance of __future__._Feature

print_function: Instance of __future__._Feature

tf_agents_gauge: Instance of tensorflow.python.eager.monitoring.BoolGauge