Specifies when to prune a layer and the sparsity (%) at each training step.
PruningSchedule controls pruning during training by indicating, at each step, whether the layer's weights should be pruned and the sparsity (%) at which they should be pruned.
It can be invoked as a callable by providing the training step Tensor. It returns a tuple of bool and float tensors:
should_prune, sparsity = pruning_schedule(step)
You can subclass PruningSchedule to write your own custom pruning schedule.
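A minimal sketch of a custom schedule, assuming only the contract described above: a callable that takes the training step and returns `(should_prune, sparsity)`, plus `get_config()`/`from_config()`. To stay runnable without TensorFlow, it uses a hypothetical stand-in base class `PruningScheduleSketch` built on Python's `abc` module, plain Python numbers instead of tensors, and sparsity expressed as a fraction in [0, 1]. In real code you would subclass `PruningSchedule` itself and return tensors. `LinearRampSchedule` and its parameters are illustrative, not part of the library.

```python
import abc


class PruningScheduleSketch(abc.ABC):
    """Hypothetical stand-in for PruningSchedule (plain Python, no TensorFlow)."""

    @classmethod
    def from_config(cls, config):
        # Assumption of this sketch: the config holds exactly the constructor arguments.
        return cls(**config)

    @abc.abstractmethod
    def get_config(self):
        """Returns a dict that from_config() can consume."""

    @abc.abstractmethod
    def __call__(self, step):
        """Returns (should_prune, sparsity) for the given training step."""


class LinearRampSchedule(PruningScheduleSketch):
    """Hypothetical schedule: sparsity grows linearly from 0 to final_sparsity
    between begin_step and end_step, and pruning happens every `frequency` steps."""

    def __init__(self, final_sparsity, begin_step, end_step, frequency=1):
        if end_step <= begin_step:
            raise ValueError("end_step must be greater than begin_step")
        if frequency < 1:
            raise ValueError("frequency must be at least 1")
        self.final_sparsity = final_sparsity
        self.begin_step = begin_step
        self.end_step = end_step
        self.frequency = frequency

    def get_config(self):
        return {
            "final_sparsity": self.final_sparsity,
            "begin_step": self.begin_step,
            "end_step": self.end_step,
            "frequency": self.frequency,
        }

    def __call__(self, step):
        in_window = self.begin_step <= step <= self.end_step
        on_frequency = (step - self.begin_step) % self.frequency == 0
        should_prune = in_window and on_frequency
        # Fraction of the ramp completed, clamped to [0, 1].
        progress = (step - self.begin_step) / (self.end_step - self.begin_step)
        progress = min(max(progress, 0.0), 1.0)
        sparsity = self.final_sparsity * progress
        return should_prune, sparsity


schedule = LinearRampSchedule(final_sparsity=0.8, begin_step=0, end_step=100, frequency=25)
for step in (0, 25, 50, 60, 100, 120):
    should_prune, sparsity = schedule(step)
    print(step, should_prune, round(sparsity, 2))
```

Note that at step 0 this sketch returns `should_prune=True` with a sparsity of 0, which, per the `__call__` description, means no pruning actually takes place at that step.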
Methods
from_config
@classmethod
from_config(config)
Instantiates a PruningSchedule from its config.
Args | |
---|---|
config | Output of get_config(). |
Returns | |
---|---|
A PruningSchedule instance. |
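The round trip between `get_config()` and `from_config()` can be sketched as follows, assuming (as in this sketch only) that the config is a plain dict of constructor arguments that `from_config()` passes back as `cls(**config)`. `PruningScheduleSketch` and `ConstantScheduleSketch` are hypothetical pure-Python classes used so the example runs without TensorFlow; they are not the library's implementation.

```python
import abc


class PruningScheduleSketch(abc.ABC):
    """Hypothetical stand-in for PruningSchedule."""

    @classmethod
    def from_config(cls, config):
        # Assumption: config maps constructor argument names to values.
        return cls(**config)

    @abc.abstractmethod
    def get_config(self):
        ...

    @abc.abstractmethod
    def __call__(self, step):
        ...


class ConstantScheduleSketch(PruningScheduleSketch):
    """Hypothetical: prune to a fixed sparsity every `frequency` steps from begin_step on."""

    def __init__(self, target_sparsity, begin_step=0, frequency=100):
        self.target_sparsity = target_sparsity
        self.begin_step = begin_step
        self.frequency = frequency

    def get_config(self):
        return {
            "target_sparsity": self.target_sparsity,
            "begin_step": self.begin_step,
            "frequency": self.frequency,
        }

    def __call__(self, step):
        should_prune = (
            step >= self.begin_step and (step - self.begin_step) % self.frequency == 0
        )
        return should_prune, self.target_sparsity


original = ConstantScheduleSketch(target_sparsity=0.5, begin_step=10, frequency=5)
config = original.get_config()  # serialize to a plain dict
restored = ConstantScheduleSketch.from_config(config)  # rebuild an equivalent schedule
print(config)
print(restored(20), restored(12))
```

Because `from_config` is a classmethod on the base, every subclass inherits it and gets back an instance of its own type, as long as its `get_config()` returns arguments its constructor accepts.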
get_config
@abc.abstractmethod
get_config()
__call__
@abc.abstractmethod
__call__(step)
Returns whether the weights should be pruned at this step and the sparsity (%) to apply.
If the returned sparsity (%) is 0, pruning is ignored for the step.
Args | |
---|---|
step | Current step in graph execution. |
Returns | |
---|---|
A tuple (should_prune, sparsity): a bool tensor indicating whether to prune at this step, and the sparsity (%) that should be applied to the weights for the step. |
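The sketch below shows how a training loop might consume the `(should_prune, sparsity)` pair and skip the step when the sparsity is 0, as described for `__call__`. It assumes magnitude-based pruning (zeroing the given fraction of smallest-magnitude weights), which this page does not itself specify. `ramp_schedule`, `prune_by_magnitude`, and `run_training` are hypothetical helpers; plain Python lists stand in for weight tensors and sparsity is a fraction in [0, 1].

```python
import math


def ramp_schedule(step, final_sparsity=0.75, end_step=4):
    """Hypothetical schedule callable: linear ramp to final_sparsity, prune every step."""
    sparsity = final_sparsity * min(step / end_step, 1.0)
    return True, sparsity


def prune_by_magnitude(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest magnitude."""
    num_to_zero = math.floor(len(weights) * sparsity)
    if num_to_zero == 0:
        return list(weights)
    # Indices of the smallest-magnitude weights (stable sort breaks ties by index).
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    drop = set(order[:num_to_zero])
    return [0.0 if i in drop else w for i, w in enumerate(weights)]


def run_training(schedule, weights, num_steps):
    """Queries the schedule each step and prunes only when it asks for it."""
    history = []
    for step in range(num_steps):
        should_prune, sparsity = schedule(step)
        # A sparsity of 0 means pruning is ignored for this step.
        if should_prune and sparsity > 0:
            weights = prune_by_magnitude(weights, sparsity)
            history.append((step, sparsity))
    return weights, history


weights = [0.9, -0.1, 0.4, -0.7, 0.05, 0.3, -0.2, 0.6]
final_weights, history = run_training(ramp_schedule, weights, num_steps=5)
print(history)
print(final_weights)
```

Step 0 is skipped because the schedule returns a sparsity of 0 there; from step 1 on, the fraction of zeroed weights grows with the schedule.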