Replaces tf.Variable
initializers so they load from a checkpoint file.
tf.compat.v1.train.init_from_checkpoint(
ckpt_dir_or_file, assignment_map
)
Migrate to TF2
tf.compat.v1.train.init_from_checkpoint
is not recommended for restoring
variable values in TF2.
To restore checkpoints in TF2, please use
tf.keras.Model.load_weights
or tf.train.Checkpoint.restore
. These APIs
use an object-based method of checkpointing, while
tf.compat.v1.train.init_from_checkpoint
relies on a more fragile variable-name-based
method of checkpointing. There is no object-based equivalent of
init_from_checkpoint
in TF2.
Please rewrite your checkpoints immediately using the object-based APIs. See the migration guide for more details.
You can load a name-based checkpoint written by tf.compat.v1.train.Saver
using tf.train.Checkpoint.restore
or tf.keras.Model.load_weights
. However,
you may have to change the names of the variables in your model to match the
variable names in the name-based checkpoint, which can be viewed with
tf.train.list_variables(path)
.
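As a minimal sketch of inspecting a name-based checkpoint, the snippet below writes a tiny checkpoint with tf.compat.v1.train.Saver and then lists its tensor names and shapes. The temporary path and the variable names ("sequential/kernel", "sequential/bias") are hypothetical and chosen only for illustration.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1

# Hypothetical checkpoint location, used only for this illustration.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Write a small name-based checkpoint with the TF1 Saver.
with tf.Graph().as_default():
    with tf1.variable_scope("sequential"):
        tf1.get_variable("kernel", shape=[3, 2],
                         initializer=tf1.zeros_initializer())
        tf1.get_variable("bias", shape=[2],
                         initializer=tf1.zeros_initializer())
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# list_variables yields (name, shape) pairs for the tensors in the checkpoint.
names_and_shapes = {
    name: list(shape) for name, shape in tf.train.list_variables(ckpt_prefix)
}
for name, shape in sorted(names_and_shapes.items()):
    print(name, shape)
```

The printed names are the ones your model's variables, or the keys of an assignment map, would need to match.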
Another option is to create an assignment_map
that maps the names of the
variables in the name-based checkpoint to the variables in your model, e.g.:
{
    'sequential/dense/bias': model.variables[0],
    'sequential/dense/kernel': model.variables[1]
}
and use tf.compat.v1.train.init_from_checkpoint(path, assignment_map)
to
restore the name-based checkpoint.
After restoring, re-encode your checkpoint using tf.train.Checkpoint.save
or tf.keras.Model.save_weights
.
Description
Values are not loaded immediately, but when the initializer is run
(typically by running a tf.compat.v1.global_variables_initializer
op).
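The deferred loading described above can be sketched as follows, assuming graph mode and a throwaway checkpoint written to a temporary directory. The scope names "old_scope" and "new_scope" are hypothetical. The new variable is declared with a zeros initializer, but after init_from_checkpoint rewires that initializer, running the global initializer fills it with the checkpoint's values.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1

# Hypothetical checkpoint location, used only for this illustration.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Step 1: save a variable filled with ones under 'old_scope'.
with tf.Graph().as_default():
    with tf1.variable_scope("old_scope"):
        tf1.get_variable("var1", shape=[2, 3],
                         initializer=tf1.ones_initializer())
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# Step 2: build a new graph whose variable would normally start at zero.
with tf.Graph().as_default():
    with tf1.variable_scope("new_scope"):
        var1 = tf1.get_variable("var1", shape=[2, 3],
                                initializer=tf1.zeros_initializer())

    # Replaces the variable's initializer; the values themselves are
    # loaded only when that initializer is run.
    tf1.train.init_from_checkpoint(ckpt_prefix, {"old_scope/": "new_scope/"})

    with tf1.Session() as sess:
        # Running the initializer is what pulls the values from the checkpoint.
        sess.run(tf1.global_variables_initializer())
        restored = sess.run(var1)

print(restored)
```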
The assignment map supports the following syntax:
'checkpoint_scope_name/': 'scope_name/'
- will load all variables in the current scope_name from checkpoint_scope_name with matching tensor names.
'checkpoint_scope_name/some_other_variable': 'scope_name/variable_name'
- will initialize the scope_name/variable_name variable from checkpoint_scope_name/some_other_variable.
'scope_variable_name': variable
- will initialize the given tf.Variable object with tensor 'scope_variable_name' from the checkpoint.
'scope_variable_name': list(variable)
- will initialize a list of partitioned variables with tensor 'scope_variable_name' from the checkpoint.
'/': 'scope_name/'
- will load all variables in the current scope_name from the checkpoint's root (i.e. no scope).
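To make the string-based forms above concrete, here is a simplified pure-Python model of how names might be paired up. It is not TensorFlow's implementation: resolve_assignment_map is a hypothetical helper, and it ignores tf.Variable values, partitioned variables, and shape checks.

```python
def resolve_assignment_map(assignment_map, checkpoint_names, graph_names):
    """Pair graph variable names with checkpoint tensor names.

    Simplified illustration only; handles string-to-string entries.
    Returns {graph_variable_name: checkpoint_tensor_name}.
    """
    resolved = {}
    for ckpt_key, target in assignment_map.items():
        if ckpt_key.endswith("/"):
            # Scope-to-scope form; '/' alone stands for the checkpoint root.
            ckpt_scope = "" if ckpt_key == "/" else ckpt_key
            if not target.endswith("/"):
                raise ValueError("A scope key must map to a scope: %r" % target)
            for graph_name in sorted(graph_names):
                if graph_name.startswith(target):
                    ckpt_name = ckpt_scope + graph_name[len(target):]
                    if ckpt_name not in checkpoint_names:
                        raise ValueError("%r not found in checkpoint" % ckpt_name)
                    resolved[graph_name] = ckpt_name
        else:
            # Single tensor to single variable form.
            if ckpt_key not in checkpoint_names:
                raise ValueError("%r not found in checkpoint" % ckpt_key)
            if target not in graph_names:
                raise ValueError("%r not found in current graph" % target)
            resolved[target] = ckpt_key
    return resolved


checkpoint_names = {"old_scope_1/var1", "old_scope_1/var2", "root_var"}
graph_names = {"new_scope_1/var1", "new_scope_2/var2", "top/root_var"}

by_scope = resolve_assignment_map(
    {"old_scope_1/": "new_scope_1/"}, checkpoint_names, graph_names)
by_name = resolve_assignment_map(
    {"old_scope_1/var2": "new_scope_2/var2"}, checkpoint_names, graph_names)
from_root = resolve_assignment_map(
    {"/": "top/"}, checkpoint_names, graph_names)
```

In this model, a scope entry pairs every graph variable under the target scope with the same-named tensor under the checkpoint scope, while a full-name entry pairs exactly one tensor with one variable.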
Supports loading into partitioned variables, which are represented as
'<variable>/part_<part #>'
.
Assignment map can be a dict, or a list of pairs. The latter is necessary to initialize multiple variables in the current graph from the same variable in the checkpoint.
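The list-of-pairs form might be used as in the sketch below, which assumes a TensorFlow version whose init_from_checkpoint accepts a list as described above. The checkpoint path and the scope names "old_scope", "copy_a", and "copy_b" are hypothetical. A dict could not express this mapping, because both entries share the same key.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1

# Hypothetical checkpoint location, used only for this illustration.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Save one variable holding the value 7.0 in every element.
with tf.Graph().as_default():
    with tf1.variable_scope("old_scope"):
        tf1.get_variable("var1", shape=[2],
                         initializer=tf1.constant_initializer(7.0))
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# Initialize two graph variables from the same checkpoint tensor.
with tf.Graph().as_default():
    with tf1.variable_scope("copy_a"):
        var_a = tf1.get_variable("var1", shape=[2],
                                 initializer=tf1.zeros_initializer())
    with tf1.variable_scope("copy_b"):
        var_b = tf1.get_variable("var1", shape=[2],
                                 initializer=tf1.zeros_initializer())

    # The same checkpoint name appears twice, so a list of pairs is used.
    tf1.train.init_from_checkpoint(ckpt_prefix, [
        ("old_scope/var1", "copy_a/var1"),
        ("old_scope/var1", "copy_b/var1"),
    ])

    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        value_a, value_b = sess.run([var_a, var_b])

print(value_a, value_b)
```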
Example:
# Say, '/tmp/model.ckpt' has the following tensors:
# -- name='old_scope_1/var1', shape=[20, 2]
# -- name='old_scope_1/var2', shape=[50, 4]
# -- name='old_scope_2/var3', shape=[100, 100]
import tensorflow as tf

# Create new model's variables
with tf.compat.v1.variable_scope('new_scope_1'):
  var1 = tf.compat.v1.get_variable('var1', shape=[20, 2],
                                   initializer=tf.compat.v1.zeros_initializer())
with tf.compat.v1.variable_scope('new_scope_2'):
  var2 = tf.compat.v1.get_variable('var2', shape=[50, 4],
                                   initializer=tf.compat.v1.zeros_initializer())
  # Partition into 5 variables along the first axis.
  var3 = tf.compat.v1.get_variable(name='var3', shape=[100, 100],
                                   initializer=tf.compat.v1.zeros_initializer(),
                                   partitioner=lambda shape, dtype: [5, 1])
# Initialize all variables in `new_scope_1` from `old_scope_1`.
tf.compat.v1.train.init_from_checkpoint(
    '/tmp/model.ckpt', {'old_scope_1/': 'new_scope_1/'})
# Use names to specify which variables to initialize from checkpoint.
tf.compat.v1.train.init_from_checkpoint(
    '/tmp/model.ckpt',
    {'old_scope_1/var1': 'new_scope_1/var1',
     'old_scope_1/var2': 'new_scope_2/var2'})
# Or use tf.Variable objects to identify what to initialize.
tf.compat.v1.train.init_from_checkpoint(
    '/tmp/model.ckpt',
    {'old_scope_1/var1': var1,
     'old_scope_1/var2': var2})
# Initialize partitioned variables using the variable's name.
tf.compat.v1.train.init_from_checkpoint(
    '/tmp/model.ckpt',
    {'old_scope_2/var3': 'new_scope_2/var3'})
# Or specify the list of tf.Variable objects.
tf.compat.v1.train.init_from_checkpoint(
    '/tmp/model.ckpt',
    {'old_scope_2/var3': var3._get_variable_list()})
Raises | |
---|---|
ValueError | If variables are missing in the current graph, or if checkpoints or tensors in checkpoints are missing. |