SGD

public class SGD<Model: Differentiable>: Optimizer
where
  Model.TangentVector: VectorProtocol & ElementaryFunctions & KeyPathIterable,
  Model.TangentVector.VectorSpaceScalar == Float

Un ottimizzatore di discesa del gradiente stocastico (SGD).

Implementa l'algoritmo di discesa del gradiente stocastico con supporto per lo slancio, il decadimento del tasso di apprendimento e lo slancio di Nesterov. Il momento e il momento di Nesterov (noto anche come metodo del gradiente accelerato di Nesterov) sono metodi di ottimizzazione del primo ordine che possono migliorare la velocità di allenamento e il tasso di convergenza della discesa del gradiente.

Riferimenti:

  • Dichiarazione

    public typealias Model = Model
  • Il tasso di apprendimento.

    Dichiarazione

    public var learningRate: Float
  • Il fattore slancio. Accelera la discesa del gradiente stocastico nella direzione rilevante e smorza le oscillazioni.

    Dichiarazione

    public var momentum: Float
  • Il decadimento del tasso di apprendimento.

    Dichiarazione

    public var decay: Float
  • Usa lo slancio di Nesterov se è vero.

    Dichiarazione

    public var nesterov: Bool
  • Lo stato di velocità del modello.

    Dichiarazione

    public var velocity: Model.TangentVector
  • L'insieme dei passi compiuti.

    Dichiarazione

    public var step: Int
  • Crea un'istanza per model .

    Dichiarazione

    public init(
      for model: __shared Model,
      learningRate: Float = 0.01,
      momentum: Float = 0,
      decay: Float = 0,
      nesterov: Bool = false
    )

    Parametri

    learningRate

    Il tasso di apprendimento. Il valore predefinito è 0.01 .

    momentum

    Il fattore di quantità di moto che accelera la discesa del gradiente stocastico nella direzione rilevante e smorza le oscillazioni. Il valore predefinito è 0 .

    decay

    Il decadimento del tasso di apprendimento. Il valore predefinito è 0 .

    nesterov

    Usa lo slancio di Nesterov se e solo se è true . Il valore predefinito è true .

  • Dichiarazione

    public func update(_ model: inout Model, along direction: Model.TangentVector)
  • Dichiarazione

    public required init(copying other: SGD, to device: Device)