An abstract class defining the API required for training.
orbit.AbstractTrainer(
name=None
)
| Attribute | Description |
|---|---|
| `name` | Returns the name of this module as passed or determined in the constructor. |
| `name_scope` | Returns a `tf.name_scope` instance for this class. |
| `non_trainable_variables` | Sequence of non-trainable variables owned by this module and its submodules. |
| `submodules` | Sequence of all submodules. Submodules are modules that are properties of this module, or that are found as properties of modules that are properties of this module (and so on). |
| `trainable_variables` | Sequence of trainable variables owned by this module and its submodules. |
| `variables` | Sequence of variables owned by this module and its submodules. |
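The `submodules` and `*variables` attributes rely on discovery through attributes: a module finds what it owns by walking its own properties, then the properties of those submodules, and so on. The following is a minimal, TensorFlow-free sketch of that idea. `Variable`, `Module`, `_attribute_items`, `Dense`, and `Net` are hypothetical stand-ins written for illustration; this is not how `tf.Module` is actually implemented.

```python
from typing import Any, Iterator, List, Tuple


class Variable:
    """Hypothetical stand-in for tf.Variable (no TensorFlow required)."""

    def __init__(self, value: Any, trainable: bool = True):
        self.value = value
        self.trainable = trainable


def _attribute_items(obj: Any) -> Iterator[Any]:
    """Yields an object's attribute values, looking one level into containers."""
    for value in vars(obj).values():
        if isinstance(value, dict):
            yield from value.values()
        elif isinstance(value, (list, tuple)):
            yield from value
        else:
            yield value


class Module:
    """Hypothetical stand-in for tf.Module; illustrates attribute discovery only."""

    @property
    def submodules(self) -> Tuple["Module", ...]:
        # Visit every module reachable through attributes, each exactly once.
        found: List[Module] = []
        seen = {id(self)}
        stack: List[Module] = [self]
        while stack:
            for item in _attribute_items(stack.pop()):
                if isinstance(item, Module) and id(item) not in seen:
                    seen.add(id(item))
                    found.append(item)
                    stack.append(item)
        return tuple(found)

    @property
    def variables(self) -> Tuple[Variable, ...]:
        # Collect variables owned by this module or any submodule, without duplicates.
        found: List[Variable] = []
        seen = set()
        for module in (self,) + self.submodules:
            for item in _attribute_items(module):
                if isinstance(item, Variable) and id(item) not in seen:
                    seen.add(id(item))
                    found.append(item)
        return tuple(found)

    @property
    def trainable_variables(self) -> Tuple[Variable, ...]:
        return tuple(v for v in self.variables if v.trainable)

    @property
    def non_trainable_variables(self) -> Tuple[Variable, ...]:
        return tuple(v for v in self.variables if not v.trainable)


# Usage: a two-layer "network" built from the hypothetical classes above.
class Dense(Module):
    def __init__(self):
        self.kernel = Variable([[0.0]])
        self.calls = Variable(0, trainable=False)


class Net(Module):
    def __init__(self):
        self.layers = [Dense(), Dense()]


net = Net()
print(len(net.submodules), len(net.trainable_variables), len(net.non_trainable_variables))
# prints: 2 2 2
```

The variables stored on each `Dense` instance are reported by `net` even though `Net` never lists them directly, which is the "and so on" recursion described for `submodules`.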
Methods
train
@abc.abstractmethod
train(num_steps: tf.Tensor) -> Optional[Output]
Implements `num_steps` steps of training.

This method will be called by the `Controller` to perform the "inner loop" of training. This inner loop amortizes the cost of bookkeeping associated with checkpointing, evaluation, and writing summaries, since that work happens between calls to `train` rather than after every step. Additionally, the inner loop can be implemented (if desired) using TensorFlow's looping constructs (e.g., a `for` loop over a `tf.range` inside a `tf.function`), which can be necessary for getting optimal performance when running on TPU. For cases that don't require peak performance, a simple Python loop can be used instead.
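To make the inner/outer loop split concrete, here is a minimal sketch that runs without TensorFlow or Orbit. `AbstractTrainerSketch` only mirrors the `train` contract and is not `orbit.AbstractTrainer`; `LinearRegressionTrainer`, its `global_step` counter, and `run_outer_loop` are hypothetical, and `run_outer_loop` merely imitates the role the `Controller` plays. It uses the simple Python loop the text allows; a TPU-oriented implementation would instead iterate over `tf.range(num_steps)` inside a `tf.function`, and would receive `num_steps` as a `tf.Tensor` rather than an `int`.

```python
import abc
from typing import Dict, List, Optional, Tuple


class AbstractTrainerSketch(abc.ABC):
    """Hypothetical stand-in for the train() contract of orbit.AbstractTrainer."""

    @abc.abstractmethod
    def train(self, num_steps: int) -> Optional[Dict[str, float]]:
        """Runs `num_steps` steps of training and optionally returns logs."""


class LinearRegressionTrainer(AbstractTrainerSketch):
    """Fits y = w * x by gradient descent on a tiny fixed dataset."""

    def __init__(self, learning_rate: float = 0.1):
        self.w = 0.0
        self.learning_rate = learning_rate
        self.data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # generated with w = 2
        self.global_step = 0  # hypothetical step counter owned by the trainer

    def _loss_and_grad(self) -> Tuple[float, float]:
        # Mean squared error and its derivative with respect to w.
        n = len(self.data)
        loss = sum((self.w * x - y) ** 2 for x, y in self.data) / n
        grad = sum(2 * (self.w * x - y) * x for x, y in self.data) / n
        return loss, grad

    def train(self, num_steps: int) -> Optional[Dict[str, float]]:
        # The "inner loop": consecutive steps with no bookkeeping in between.
        for _ in range(num_steps):
            _, grad = self._loss_and_grad()
            self.w -= self.learning_rate * grad
            self.global_step += 1
        loss, _ = self._loss_and_grad()
        return {"loss": loss}


def run_outer_loop(trainer: LinearRegressionTrainer, total_steps: int,
                   steps_per_loop: int) -> List[Tuple[int, Optional[Dict[str, float]]]]:
    """Hypothetical Controller-like driver that does bookkeeping between calls."""
    history = []
    while trainer.global_step < total_steps:
        num_steps = min(steps_per_loop, total_steps - trainer.global_step)
        logs = trainer.train(num_steps)
        # Checkpointing, evaluation, and summary writing would happen here,
        # once per `steps_per_loop` steps rather than once per step.
        history.append((trainer.global_step, logs))
    return history


trainer = LinearRegressionTrainer()
history = run_outer_loop(trainer, total_steps=100, steps_per_loop=25)
print([step for step, _ in history])  # [25, 50, 75, 100]
```

With 100 total steps and 25 steps per call, the bookkeeping point in `run_outer_loop` is reached 4 times instead of 100, which is the amortization the paragraph above describes.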
| Args | |
|---|---|
| `num_steps` | The number of training steps to run. Note that it is up to the model what constitutes a "step", which may involve more than one update to model parameters (e.g., if training a GAN). |
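Because a "step" is model-defined, `num_steps` need not equal the number of parameter updates. The toy sketch below, written without TensorFlow, applies two updates per step, loosely analogous to a GAN step that updates a discriminator and then a generator. `AlternatingTrainer` and its objective f(a, b) = (a - 1)² + (b + 2)² are hypothetical choices made only to keep the arithmetic checkable.

```python
from typing import Dict


class AlternatingTrainer:
    """Hypothetical trainer whose single "step" applies two parameter updates.

    Loosely analogous to a GAN step that updates the discriminator and then
    the generator. Here it minimizes f(a, b) = (a - 1)**2 + (b + 2)**2 by
    alternating one gradient update on `a` with one on `b`.
    """

    def __init__(self, learning_rate: float = 0.25):
        self.a = 0.0
        self.b = 0.0
        self.learning_rate = learning_rate
        self.steps = 0    # what `num_steps` counts
        self.updates = 0  # parameter updates actually applied

    def train(self, num_steps: int) -> Dict[str, float]:
        for _ in range(num_steps):
            # First update of the step (the "discriminator-like" update).
            self.a -= self.learning_rate * 2 * (self.a - 1)
            self.updates += 1
            # Second update of the same step (the "generator-like" update).
            self.b -= self.learning_rate * 2 * (self.b + 2)
            self.updates += 1
            self.steps += 1
        return {"a": self.a, "b": self.b}


trainer = AlternatingTrainer()
trainer.train(num_steps=40)
print(trainer.steps, trainer.updates)  # 40 80
```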
| Returns | |
|---|---|
| Either `None`, or a dictionary mapping names to `Tensor`s or NumPy values. If a dictionary is returned, it will be written to logs and as TensorBoard summaries. The dictionary may also be nested, which will generate a hierarchy of summary directories. |
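A nested return value can be thought of as a tree whose paths become summary locations. The sketch below only illustrates that mapping by joining nested keys with `/`; it does not reproduce the exact directory layout or tag names Orbit writes. `flatten_logs` and `summarize` are hypothetical helpers, and the example metric names are made up.

```python
from typing import Any, Dict, Iterator, Mapping, Optional, Tuple


def flatten_logs(logs: Mapping[str, Any], prefix: str = "") -> Iterator[Tuple[str, Any]]:
    """Yields (path, value) pairs, joining nested keys with '/'."""
    for key, value in logs.items():
        path = f"{prefix}/{key}" if prefix else key
        if isinstance(value, Mapping):
            # A nested dictionary becomes one more level of the hierarchy.
            yield from flatten_logs(value, path)
        else:
            yield path, value


def summarize(logs: Optional[Mapping[str, Any]]) -> Dict[str, Any]:
    """Handles both allowed return values of train(): None or a dictionary."""
    return dict(flatten_logs(logs)) if logs is not None else {}


logs = {"loss": 0.42, "gan": {"generator_loss": 1.3, "discriminator_loss": 0.7}}
print(summarize(logs))
# {'loss': 0.42, 'gan/generator_loss': 1.3, 'gan/discriminator_loss': 0.7}
print(summarize(None))  # {}
```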
with_name_scope
@classmethod
with_name_scope(method)
Decorator to automatically enter the module name scope.
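import tensorflow as tf

class MyModule(tf.Module):
    @tf.Module.with_name_scope
    def __call__(self, x):
        # Create the variable lazily on the first call, inside the module's name scope.
        if not hasattr(self, 'w'):
            self.w = tf.Variable(tf.random.normal([x.shape[1], 3]))
        return tf.matmul(x, self.w)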
Using the above module would produce `tf.Variable`s and `tf.Tensor`s whose names include the module name:
>>> mod = MyModule()
>>> mod(tf.ones([1, 2]))
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=...>
>>> mod.w
<tf.Variable 'my_module/Variable:0' shape=(2, 3) dtype=float32, numpy=...>
| Args | |
|---|---|
| `method` | The method to wrap. |

| Returns | |
|---|---|
| The original method wrapped such that it enters the module's name scope. |