Adds adversarial regularization to a `tf.estimator.Estimator`.
nsl.estimator.add_adversarial_regularization(
estimator, optimizer_fn=None, adv_config=None
)
The returned estimator includes the adversarial loss as a regularization term in its training objective and is trained using the optimizer provided by `optimizer_fn`. `optimizer_fn` (along with its hyperparameters) should match the optimizer used in the base `estimator`.

If `optimizer_fn` is not set, a default `tf.train.AdagradOptimizer` with `learning_rate=0.05` is used.
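The regularized objective described above can be illustrated outside TensorFlow. The plain-Python sketch below is a conceptual stand-in, not NSL's implementation: it assumes a hand-written logistic-regression loss, builds one adversarial neighbor by stepping the input along its L2-normalized loss gradient (in the spirit of the fast gradient method), and adds the loss on that neighbor, scaled by a weight, to the base loss. The names `multiplier` and `step_size` are illustrative assumptions that loosely mirror hyperparameters carried by `nsl.configs.AdvRegConfig`; in the real wrapper, that config controls how the perturbation and weighting are done.

```python
import math
from typing import List


def sigmoid(z: float) -> float:
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-z))


def logit(w: List[float], b: float, x: List[float]) -> float:
    """Linear score w.x + b of a toy logistic-regression model (assumed base model)."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b


def base_loss(w: List[float], b: float, x: List[float], y: int) -> float:
    """Binary cross-entropy of the toy model on a single example."""
    p = sigmoid(logit(w, b, x))
    tiny = 1e-12  # guards against log(0)
    return -(y * math.log(p + tiny) + (1 - y) * math.log(1.0 - p + tiny))


def input_gradient(w: List[float], b: float, x: List[float], y: int) -> List[float]:
    """Gradient of the loss with respect to the input x, which is (p - y) * w."""
    p = sigmoid(logit(w, b, x))
    return [(p - y) * wi for wi in w]


def adversarial_neighbor(w, b, x, y, step_size: float) -> List[float]:
    """Move x by step_size along the L2-normalized input gradient (loss-increasing direction)."""
    g = input_gradient(w, b, x, y)
    norm = math.sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        return list(x)  # no informative direction; leave the input unchanged
    return [xi + step_size * gi / norm for xi, gi in zip(x, g)]


def regularized_loss(w, b, x, y, multiplier: float, step_size: float) -> float:
    """Base loss plus a weighted loss on the adversarial neighbor (hypothetical helper)."""
    x_adv = adversarial_neighbor(w, b, x, y, step_size)
    return base_loss(w, b, x, y) + multiplier * base_loss(w, b, x_adv, y)


w, b, x, y = [1.0, -2.0], 0.5, [0.3, 0.1], 1
print(base_loss(w, b, x, y), regularized_loss(w, b, x, y, multiplier=0.2, step_size=0.1))
```

Setting `multiplier` to 0 recovers the base loss, which is a quick way to see that the adversarial term acts purely as an additive regularizer.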
| Args | |
|---|---|
| `estimator` | A `tf.estimator.Estimator` object, the base model. |
| `optimizer_fn` | A function that accepts no arguments and returns an instance of `tf.train.Optimizer`. This optimizer (instead of the one used in `estimator`) will be used to train the model. If not specified, defaults to `tf.train.AdagradOptimizer` with `learning_rate=0.05`. |
| `adv_config` | An instance of `nsl.configs.AdvRegConfig` that specifies various hyperparameters for adversarial regularization. |
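Because `optimizer_fn` takes no arguments, it acts as a factory: the caller can defer building the optimizer until it is needed and fall back to Adagrad when the argument is `None`. The plain-Python sketch below models only that contract. `AdagradSketch` and `resolve_optimizer` are hypothetical stand-ins, not part of NSL or TensorFlow; the update follows the textbook Adagrad rule, and the initial accumulator of 0.1 is assumed to match the default of `tf.train.AdagradOptimizer`.

```python
import math
from typing import Callable, List, Optional


class AdagradSketch:
    """Minimal Adagrad: keeps a per-parameter running sum of squared gradients."""

    def __init__(self, learning_rate: float, initial_accumulator_value: float = 0.1):
        self.learning_rate = learning_rate
        self.initial_accumulator_value = initial_accumulator_value
        self._accum: Optional[List[float]] = None

    def apply_gradients(self, params: List[float], grads: List[float]) -> List[float]:
        """Return updated parameters; the step shrinks as squared gradients accumulate."""
        if self._accum is None:
            self._accum = [self.initial_accumulator_value] * len(params)
        new_params = []
        for i, (p, g) in enumerate(zip(params, grads)):
            self._accum[i] += g * g
            new_params.append(p - self.learning_rate * g / math.sqrt(self._accum[i]))
        return new_params


# A zero-argument callable that builds an optimizer (hypothetical type alias).
OptimizerFn = Callable[[], AdagradSketch]


def resolve_optimizer(optimizer_fn: Optional[OptimizerFn] = None) -> AdagradSketch:
    """Call the factory if given; otherwise use the documented default (Adagrad, lr=0.05)."""
    if optimizer_fn is None:
        return AdagradSketch(learning_rate=0.05)
    return optimizer_fn()


default_opt = resolve_optimizer()
custom_opt = resolve_optimizer(lambda: AdagradSketch(learning_rate=0.01))
print(default_opt.apply_gradients([1.0], [0.3]))
```

With TensorFlow, the factory would instead return a real optimizer, for example `lambda: tf.compat.v1.train.AdagradOptimizer(learning_rate=0.05)`, ideally matching the optimizer configured on the base estimator.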
| Returns |
|---|
| A modified `tf.estimator.Estimator` object with adversarial regularization incorporated into its loss. |