tff.learning.optimizers.build_adamw

Returns a tff.learning.optimizers.Optimizer for AdamW.

The AdamW optimizer is based on the paper Decoupled Weight Decay Regularization (Loshchilov & Hutter).

The update rule, given learning rate lr, epsilon eps, accumulator acc, preconditioner s, weight decay lambda, iteration t, weights w, and gradients g, is:

acc = beta_1 * acc + (1 - beta_1) * g
s = beta_2 * s + (1 - beta_2) * g**2
normalization = sqrt(1 - beta_2**t) / (1 - beta_1**t)
w = w - lr * (normalization * acc / (sqrt(s) + eps) + lambda * w)
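
The following is a minimal NumPy sketch of a single AdamW step that mirrors the rule above. It operates on one weight array, and the hyperparameter defaults shown are illustrative rather than guaranteed to match those of build_adamw.

```python
import numpy as np

def adamw_step(w, g, acc, s, t, lr=0.01, beta_1=0.9, beta_2=0.999,
               eps=1e-7, weight_decay=0.004):
  """Applies one AdamW update to the weight array w given gradient g."""
  acc = beta_1 * acc + (1 - beta_1) * g           # first-moment accumulator
  s = beta_2 * s + (1 - beta_2) * g**2            # second-moment preconditioner
  normalization = np.sqrt(1 - beta_2**t) / (1 - beta_1**t)  # bias correction
  # The weight_decay * w term is applied to the weights directly, decoupled
  # from the gradient-based update; setting weight_decay to 0 recovers Adam.
  w = w - lr * (normalization * acc / (np.sqrt(s) + eps) + weight_decay * w)
  return w, acc, s
```

Note that the decay term lambda * w acts on the weights directly rather than being folded into the gradient; this is what distinguishes AdamW's decoupled weight decay from L2 regularization applied under Adam.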

learning_rate: A positive float for the learning rate.
beta_1: A float between 0.0 and 1.0 for the decay used to track the previous gradients.
beta_2: A float between 0.0 and 1.0 for the decay used to track the magnitude (second moment) of previous gradients.
epsilon: A small non-negative float, used to maintain numerical stability.
weight_decay: A non-negative float governing the amount of weight decay. When set to 0, this recovers Adam.
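
A usage sketch follows, assuming the initialize/next interface of tff.learning.optimizers.Optimizer: state is created from TensorSpecs of the weights, and each call to next returns an updated state and updated weights. The keyword arguments shown are the documented parameters; other details may differ across TFF versions.

```python
import tensorflow as tf
import tensorflow_federated as tff

# Build the optimizer with explicit hyperparameters.
optimizer = tff.learning.optimizers.build_adamw(
    learning_rate=0.01, weight_decay=0.004)

# The optimizer operates on structures of tensors: create state from specs,
# then repeatedly call next(state, weights, gradients).
weights = (tf.constant([1.0, 2.0]),)
specs = tf.nest.map_structure(
    lambda w: tf.TensorSpec(w.shape, w.dtype), weights)
state = optimizer.initialize(specs)

gradients = (tf.constant([0.1, -0.2]),)
state, weights = optimizer.next(state, weights, gradients)
```

In federated training, an optimizer built this way is typically passed to a learning-process builder (for example as a client or server optimizer) rather than stepped manually; the exact entry point depends on the TFF version in use.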