tf.contrib.model_pruning.train

Wrapper around tf-slim's train function.

Runs a training loop using a TensorFlow supervisor. When a sync_optimizer is supplied, gradient updates are applied synchronously. Otherwise, gradient updates are applied asynchronously.

train_op A Tensor that, when executed, will apply the gradients and return the loss value.
logdir The directory where training logs are written to. If None, model checkpoints and summaries will not be written.
mask_update_op Operation that upon execution updates the weight masks and thresholds.
train_step_fn The function to call in order to execute a single gradient step. The function must take exactly four arguments: the current session, the train_op Tensor, a global step Tensor and a dictionary.
train_step_kwargs A dictionary which is passed to the train_step_fn. By default, two Boolean, scalar ops called "should_stop" and "should_log" are provided.
log_every_n_steps The frequency, in terms of global steps, at which the loss and global step are logged.
graph The graph to pass to the supervisor. If no graph is supplied the default graph is used.
master The address of the tensorflow master.
is_chief Specifies whether or not the training is being run by the primary replica during replica training.
global_step The Tensor representing the global step. If left as None, then slim.variables.get_or_create_global_step() is used.
number_of_steps The maximum number of gradient steps to take during training, as measured by global_step: training will stop if global_step is greater than number_of_steps. If the value is left as None, training proceeds indefinitely.
init_op The initialization operation. If left to its default value, then the session is initialized by calling tf.compat.v1.global_variables_initializer().
init_feed_dict A feed dictionary to use when executing the init_op.
local_init_op The local initialization operation. If left to its default value, then the session is initialized by calling tf.compat.v1.local_variables_initializer() and tf.compat.v1.tables_initializer().
init_fn An optional callable to be executed after init_op is called. The callable must accept one argument, the session being initialized.
ready_op Operation to check if the model is ready to use. If left to its default value, then the session checks for readiness by calling tf.compat.v1.report_uninitialized_variables().
summary_op The summary operation.
save_summaries_secs How often, in seconds, to save summaries.
summary_writer SummaryWriter to use. Can be None to indicate that no summaries should be written. If unset, a SummaryWriter is created.
startup_delay_steps The number of steps to wait before beginning training. Note that this must be 0 if a sync_optimizer is supplied.
saver Saver to save checkpoints. If None, a default one will be created and used.
save_interval_secs How often, in seconds, to save the model to logdir.
sync_optimizer An instance of tf.compat.v1.train.SyncReplicasOptimizer, or a list of them. If the argument is supplied, gradient updates will be synchronous. If left as None, gradient updates will be asynchronous.
session_config An instance of tf.compat.v1.ConfigProto that will be used to configure the Session. If left as None, the default will be used.
trace_every_n_steps Produce and save a Timeline in Chrome trace format and add it to the summaries every trace_every_n_steps steps. If None, no trace information will be produced or saved.
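
The four-argument train_step_fn contract and the role of mask_update_op can be illustrated without TensorFlow. The following is a minimal sketch, not the library's implementation: FakeSession, my_train_step and with_mask_update are hypothetical names, the "ops" are plain Python callables, and both the (loss, should_stop) return convention and the choice to run the mask update after each step are assumptions that this page does not confirm.

```python
# Sketch only: plain-Python stand-ins, no TensorFlow required.
class FakeSession:
    """Hypothetical stand-in for a session; run() calls the given 'op'."""

    def __init__(self):
        self.step = 0          # plays the role of the global step
        self.mask_updates = 0  # counts mask_update_op executions

    def run(self, op):
        return op(self)


def train_op(sess):
    """Pretend to apply gradients, then return a loss value."""
    sess.step += 1
    return 1.0 / sess.step


def global_step(sess):
    return sess.step


def mask_update_op(sess):
    sess.mask_updates += 1


def my_train_step(sess, train_op, global_step, train_step_kwargs):
    """A train_step_fn candidate taking exactly four arguments."""
    loss = sess.run(train_op)
    if sess.run(train_step_kwargs["should_log"]):
        print("global step %d: loss = %.4f" % (sess.run(global_step), loss))
    # Assumed return convention: (loss, should_stop).
    return loss, sess.run(train_step_kwargs["should_stop"])


def with_mask_update(step_fn, mask_update_op):
    """Run mask_update_op after every call to step_fn (assumed ordering)."""
    def wrapped(sess, train_op, global_step, train_step_kwargs):
        result = step_fn(sess, train_op, global_step, train_step_kwargs)
        sess.run(mask_update_op)
        return result
    return wrapped


sess = FakeSession()
kwargs = {
    "should_stop": lambda s: s.step >= 4,     # stop after 4 steps
    "should_log": lambda s: s.step % 2 == 0,  # log every 2 steps
}
step_fn = with_mask_update(my_train_step, mask_update_op)

done = False
while not done:
    final_loss, done = step_fn(sess, train_op, global_step, kwargs)
```

In this sketch the dictionary carries "should_stop" and "should_log" entries, mirroring the two defaults described for train_step_kwargs; in real use these would be Boolean scalar ops evaluated by the session rather than lambdas.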

Returns the value of the loss function after training.

Raises ValueError if train_op is empty, if startup_delay_steps is non-zero when sync_optimizer is supplied, if number_of_steps is negative, or if trace_every_n_steps is not None and no logdir is provided.
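
The four error conditions can be restated as a small standalone check. This is a hedged sketch in plain Python: check_train_args and raises_value_error are hypothetical helpers, "empty" is interpreted here as None, and the error messages are invented for illustration rather than taken from the library.

```python
def check_train_args(train_op, logdir=None, number_of_steps=None,
                     startup_delay_steps=0, sync_optimizer=None,
                     trace_every_n_steps=None):
    """Hypothetical pre-flight check mirroring the documented ValueErrors."""
    if train_op is None:  # assumption: "empty" means None
        raise ValueError("train_op must be supplied")
    if sync_optimizer is not None and startup_delay_steps != 0:
        raise ValueError(
            "startup_delay_steps must be 0 when sync_optimizer is supplied")
    if number_of_steps is not None and number_of_steps < 0:
        raise ValueError("number_of_steps must not be negative")
    if trace_every_n_steps is not None and logdir is None:
        raise ValueError("trace_every_n_steps requires a logdir")


def raises_value_error(**kwargs):
    """Return True if check_train_args rejects the given arguments."""
    try:
        check_train_args(**kwargs)
    except ValueError:
        return True
    return False


# A valid configuration passes silently.
check_train_args(train_op="train_op", logdir="/tmp/logs",
                 number_of_steps=100, trace_every_n_steps=10)
```

Running such a check before starting a long job surfaces configuration mistakes early, for example requesting traces without a logdir.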