훈련 루프

기계 학습 모델을 훈련할 때 훈련 데이터가 수집(또는 생성)되고, 모델을 통해 배치가 실행되고, 기울기가 얻어지고, 최적화 프로그램을 통해 모델이 업데이트되는 루프가 있는 것이 일반적입니다. 각 훈련 애플리케이션에 대해 자신만의 훈련 루프를 작성할 수 있는 반면, Swift for TensorFlow는 이 프로세스를 단순화할 수 있는 실험적인 훈련 루프 추상화를 제공합니다.

모델 저장소 내의 TrainingLoop 모듈에는 이 실험적이고 일반화된 훈련 루프의 현재 버전이 포함되어 있습니다. 간편한 데이터 수집을 위해 Epochs API를 준수하는 데이터 세트 래퍼와 통합하고, 최적의 성능을 달성하기 위해 가속기 백엔드와 모델, 데이터 세트 및 옵티마이저의 상호 작용을 자동화하는 방식으로 구성됩니다. 콜백을 사용하면 학습 프로세스를 크게 사용자 정의할 수 있습니다.

모델 저장소에 있는 대부분의 이미지 기반 예제는 지도 텍스트 모델 훈련 예제뿐만 아니라 이 훈련 루프 추상화를 사용하도록 변환되었습니다. 그러나 훈련 루프는 모든 기계 학습 모델에 대한 현재 설계에 적합하지 않을 수 있습니다.

TensorFlow의 일반화된 훈련 루프용 Swift 구현은 fastai의 Learner 에 크게 영향을 받았습니다. 디자인에 대한 자세한 내용은 "fastai: 딥 러닝을 위한 계층화된 API" 및 Sylvain Gugger의 프레젠테이션 "Fast.ai - 무한히 사용자 정의 가능한 훈련 루프"를 참조하세요.

용법

ResNet-CIFAR10 예제는 이 훈련 루프를 실제로 사용하는 방법에 대한 좋은 데모를 제공합니다. 먼저 모듈을 가져옵니다.

import TrainingLoop

그런 다음 Device 를 설정하여 가속기 백엔드를 선택합니다. 이 경우 X10 XLA 기반 백엔드를 선택하고 사용 가능한 첫 번째 가속기를 사용합니다.

let device = Device.defaultXLA

다음 단계는 훈련 루프에 사용할 데이터 세트, 모델 및 최적화 프로그램을 구성하는 것입니다.

let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)

그런 다음 훈련 루프를 설정합니다.

var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy])

훈련 루프는 사용 중인 데이터세트가 Epochs API를 준수한다고 가정하고 훈련 및 검증에 사용할 데이터세트 내의 분할을 지정할 수 있도록 해줍니다. 모든 손실 함수는 호환 가능한 래퍼(예: softmaxCrossEntropy is here ) 에 배치되면 사용할 수 있습니다.

캡처할 수 있는 현재 측정항목은 다음과 같습니다.

  • loss
  • accuracy
  • top5Accuracy
  • matthewsCorrelationCoefficient
  • perplexity

마지막으로 훈련을 수행하려면 다음을 호출합니다.

try! trainingLoop.fit(&model, epochs: 10, on: device)

그러면 우리가 지정한 가속기 백엔드를 사용하여 10세대 동안 모델을 훈련하게 됩니다. 훈련 중에 애니메이션 프롬프트를 사용하여 콘솔에 통계가 표시됩니다.

콜백

이 일반화된 훈련 루프의 사용자 정의는 콜백을 사용하여 발생합니다. 이러한 콜백은 루프 내의 다양한 지점에 연결될 수 있습니다.

여러 내장 콜백은 모든 훈련 루프에 추가할 수 있는 기능을 제공합니다. 여기에는 다음이 포함됩니다.

  • 쉼표로 구분된 값(CSV) 파일에 통계 로깅
  • 맞춤형 일정에 따라 학습률 조정
  • TensorBoard를 통한 교육 진행 상황 모니터링 및 그래프 작성

이 외에도 표준 교육 루프에 다양한 추가 기능을 추가하기 위해 사용자 정의 콜백을 직접 만들 수 있습니다.

CSV 로깅

CSVLogger 클래스는 훈련 통계를 쉼표로 구분된 값 형식으로 선택한 파일에 기록하는 콜백을 캡슐화합니다. 이 파일은 epoch , batch 이라는 레이블이 붙은 열과 훈련 루프 내에서 활성화한 모든 측정 항목으로 시작됩니다. 그런 다음 해당 열의 현재 값을 사용하여 각 배치에 대해 하나의 행이 기록됩니다.

훈련 루프에 CSV 로깅을 추가하려면 TrainingLoop 에 대한 callbacks: 매개변수에 제공된 콜백 배열에 다음과 같은 것을 추가하십시오.

try! CSVLogger(path: "file.csv").log

예를 들어 LeNet-MNIST 샘플은 훈련 루프 내에서 이를 사용합니다.

학습률 일정

모델을 훈련할 때 훈련 과정 중에 옵티마이저에 제공되는 학습률을 변경하는 것이 일반적입니다. 이는 시간에 따른 선형 감소만큼 간단할 수도 있고, 복잡한 기능으로 설명되는 워밍업 및 감소 주기만큼 복잡할 수도 있습니다.

learningRateScheduler 콜백은 각각 고유한 모양을 가진 다양한 세그먼트로 구성된 학습 속도 일정을 설명하는 수단을 제공합니다. 이는 각각 함수로 정의된 Shape , 초기 학습률 및 최종 학습률을 갖는 ScheduleSegment 로 구성된 LearningRateSchedule 정의하여 수행됩니다.

예를 들어, BERT-CoLA 샘플은 준비 기간 동안 학습률의 선형 증가와 그 이후의 선형 감소를 사용합니다. 이를 위해 학습률 일정 콜백은 다음과 같이 정의됩니다.

learningRateScheduler(
  schedule: makeSchedule(
    [
      ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
      ScheduleSegment(shape: linear, endRate: 0)
    ]
  )
)

두 개의 ScheduleSegment 는 0에서 시작하여 일련의 10개 개별 단계에 걸쳐 peakLearningRate 까지 선형적으로 증가한 다음 이전 단계의 최종 학습 속도에서 시작하고 훈련 프로세스가 끝날 때까지 선형적으로 0으로 감소하는 학습 속도를 정의합니다.

텐서보드 통합

TensorBoard 는 모델 훈련 모니터링, 완료 시 훈련 분석, 훈련 실행 비교를 위한 강력한 시각화 도구입니다. Swift for TensorFlow는 훈련 지표를 기록하는 콜백을 제공하는 모델 저장소의 TensorBoard 모듈을 사용하여 TensorBoard 시각화를 지원합니다.

GPT2-WikiText2 샘플은 모델 훈련에 TensorBoard 로깅을 추가하는 방법을 보여줍니다. 먼저 TensorBoard 모듈을 가져옵니다. 그런 다음 TrainingLoopcallbacks: 배열에 tensorBoardStatisticsLogger() 추가하는 것만큼 간단합니다.

기본적으로 run/tensorboard/stats 디렉터리 내에서 각 훈련 실행을 기록합니다. Tensorboard 내에서 이를 보려면 다음을 실행하세요.

tensorboard --logdir ./run/tensorboard/stats

TensorBoard는 훈련 지표를 볼 수 있는 로컬 서버를 시작해야 합니다. 학습 및 검증 결과는 별도로 표시되어야 하며 각 실행에는 동일한 모델의 여러 실행을 쉽게 비교할 수 있도록 고유한 타임스탬프가 있습니다.

TensorFlow TensorBoard 통합을 위한 Swift의 디자인은 tensorboardX 에서 영감을 받았습니다. TensorBoard 콜백은 적절한 이벤트 및 요약 프로토콜 버퍼를 직접 생성하고 훈련 중에 로그 파일에 기록합니다.

맞춤 콜백

위에 설명된 내장 콜백 외에도 자신만의 콜백을 생성하여 훈련 루프의 기능을 맞춤설정할 수 있습니다. 이러한 콜백은 다음과 유사한 서명이 있는 함수입니다.

func customCallback<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
{
  if event == .updateStart {
    ...
  }
}

훈련 루프와 관련 상태가 첫 번째 매개변수로 전달됩니다. 콜백이 응답하는 루프의 현재 부분은 event 통해 제공됩니다. 훈련 루프 이벤트에는 다음 상태 중 하나가 있으며, 각 상태는 루프 수명 주기의 서로 다른 지점에 해당합니다.

  • fitStart
  • fitEnd
  • epochStart
  • epochEnd
  • trainingStart
  • trainingEnd
  • validationStart
  • validationEnd
  • batchStart
  • batchEnd
  • updateStart
  • inferencePredictionEnd

콜백 함수는 위 상태의 모든 조합에서 논리를 활성화하도록 선택할 수 있으며, 이를 통해 다양한 방법으로 훈련 루프에서 데이터를 추출하거나 제어할 수 있습니다.