public class Adam<Model: Differentiable>: Optimizer
where
Model.TangentVector: VectorProtocol & PointwiseMultiplicative
& ElementaryFunctions & KeyPathIterable,
Model.TangentVector.VectorSpaceScalar == Float
Optimiseur Adam.
Implémente l'algorithme d'optimisation Adam. Adam est une méthode de descente de gradient stochastique qui calcule les taux d'apprentissage adaptatifs individuels pour différents paramètres à partir d'estimations des moments de premier et de second ordre des gradients.
Référence : « Adam : Une méthode d'optimisation stochastique » (Kingma et Ba, 2014).
Exemples :
- Former un agent d'apprentissage par renforcement simple :
...
// Instantiate an agent's policy - approximated by the neural network (`net`) after defining it
in advance.
var net = Net(observationSize: Int(observationSize), hiddenSize: hiddenSize, actionCount: actionCount)
// Define the Adam optimizer for the network with a learning rate set to 0.01.
let optimizer = Adam(for: net, learningRate: 0.01)
...
// Begin training the agent (over a certain number of episodes).
while true {
...
// Implementing the gradient descent with the Adam optimizer:
// Define the gradients (use withLearningPhase to call a closure under a learning phase).
let gradients = withLearningPhase(.training) {
TensorFlow.gradient(at: net) { net -> Tensor<Float> in
// Return a softmax (loss) function
return loss = softmaxCrossEntropy(logits: net(input), probabilities: target)
}
}
// Update the differentiable variables of the network (`net`) along the gradients with the Adam
optimizer.
optimizer.update(&net, along: gradients)
...
}
}
- Former un réseau contradictoire génératif (GAN) :
...
// Instantiate the generator and the discriminator networks after defining them.
var generator = Generator()
var discriminator = Discriminator()
// Define the Adam optimizers for each network with a learning rate set to 2e-4 and beta1 - to 0.5.
let adamOptimizerG = Adam(for: generator, learningRate: 2e-4, beta1: 0.5)
let adamOptimizerD = Adam(for: discriminator, learningRate: 2e-4, beta1: 0.5)
...
Start the training loop over a certain number of epochs (`epochCount`).
for epoch in 1...epochCount {
// Start the training phase.
...
for batch in trainingShuffled.batched(batchSize) {
// Implementing the gradient descent with the Adam optimizer:
// 1) Update the generator.
...
let 𝛁generator = TensorFlow.gradient(at: generator) { generator -> Tensor<Float> in
...
return loss
}
// Update the differentiable variables of the generator along the gradients (`𝛁generator`)
// with the Adam optimizer.
adamOptimizerG.update(&generator, along: 𝛁generator)
// 2) Update the discriminator.
...
let 𝛁discriminator = TensorFlow.gradient(at: discriminator) { discriminator -> Tensor<Float> in
...
return loss
}
// Update the differentiable variables of the discriminator along the gradients (`𝛁discriminator`)
// with the Adam optimizer.
adamOptimizerD.update(&discriminator, along: 𝛁discriminator)
}
}
Déclaration
public typealias Model = Model
Le taux d'apprentissage.
Déclaration
public var learningRate: Float
Un coefficient utilisé pour calculer les premiers instants des gradients.
Déclaration
public var beta1: Float
Un coefficient utilisé pour calculer les seconds moments des gradients.
Déclaration
public var beta2: Float
Un petit scalaire ajouté au dénominateur pour améliorer la stabilité numérique.
Déclaration
public var epsilon: Float
Le taux d’apprentissage diminue.
Déclaration
public var decay: Float
L'étape actuelle.
Déclaration
public var step: Int
Les premiers instants des poids.
Déclaration
public var firstMoments: Model.TangentVector
Les seconds instants des poids.
Déclaration
public var secondMoments: Model.TangentVector
Déclaration
public init( for model: __shared Model, learningRate: Float = 1e-3, beta1: Float = 0.9, beta2: Float = 0.999, epsilon: Float = 1e-8, decay: Float = 0 )
Paramètres
learningRate
Le taux d'apprentissage. La valeur par défaut est
1e-3
.beta1
Le taux de décroissance exponentielle pour les estimations du 1er instant. La valeur par défaut est
0.9
.beta2
Le taux de décroissance exponentielle pour les estimations du 2ème moment. La valeur par défaut est
0.999
.epsilon
Un petit scalaire ajouté au dénominateur pour améliorer la stabilité numérique. La valeur par défaut est
1e-8
.decay
Le taux d’apprentissage diminue. La valeur par défaut est
0
.Déclaration
public required init(copying other: Adam, to device: Device)