SGD

public class SGD<Model: Differentiable>: Optimizer
where
  Model.TangentVector: VectorProtocol & ElementaryFunctions & KeyPathIterable,
  Model.TangentVector.VectorSpaceScalar == Float

Un optimizador de descenso de gradiente estocástico (SGD).

Implementa el algoritmo de descenso de gradiente estocástico con soporte para impulso, caída de la tasa de aprendizaje e impulso de Nesterov. El impulso y el impulso de Nesterov (también conocido como método de gradiente acelerado de Nesterov) son métodos de optimización de primer orden que pueden mejorar la velocidad de entrenamiento y la tasa de convergencia del descenso de gradiente.

Referencias:

  • Declaración

    public typealias Model = Model
  • La tasa de aprendizaje.

    Declaración

    public var learningRate: Float
  • El factor impulso. Acelera el descenso del gradiente estocástico en la dirección relevante y amortigua las oscilaciones.

    Declaración

    public var momentum: Float
  • La tasa de aprendizaje decae.

    Declaración

    public var decay: Float
  • Utilice el impulso de Nesterov si es cierto.

    Declaración

    public var nesterov: Bool
  • El estado de velocidad del modelo.

    Declaración

    public var velocity: Model.TangentVector
  • El conjunto de pasos dados.

    Declaración

    public var step: Int
  • Crea una instancia para model .

    Declaración

    public init(
      for model: __shared Model,
      learningRate: Float = 0.01,
      momentum: Float = 0,
      decay: Float = 0,
      nesterov: Bool = false
    )

    Parámetros

    learningRate

    La tasa de aprendizaje. El valor predeterminado es 0.01 .

    momentum

    El factor de impulso que acelera el descenso del gradiente estocástico en la dirección relevante y amortigua las oscilaciones. El valor predeterminado es 0 .

    decay

    La tasa de aprendizaje decae. El valor predeterminado es 0 .

    nesterov

    Utilice el impulso de Nesterov si es true . El valor predeterminado es true .

  • Declaración

    public func update(_ model: inout Model, along direction: Model.TangentVector)
  • Declaración

    public required init(copying other: SGD, to device: Device)