선행 학습된 모델 사용

이 튜토리얼에서는 TensorFlow.js Layers API를 사용하여 전이 학습을 보여주는 예제 웹 애플리케이션을 탐색합니다. 이 예에서는 사전 훈련된 모델을 로드한 다음 브라우저에서 모델을 다시 훈련합니다.

이 모델은 MNIST 숫자 분류 데이터세트 의 숫자 0-4에 대해 Python으로 사전 학습되었습니다. 브라우저의 재훈련(또는 전이 학습)은 숫자 5-9를 사용합니다. 이 예에서는 사전 훈련된 모델의 처음 여러 계층을 사용하여 전이 학습 중에 새 데이터에서 특징을 추출할 수 있으므로 새 데이터에 대한 더 빠른 훈련이 가능함을 보여줍니다.

이 튜토리얼의 예제 애플리케이션은 온라인으로 제공 되므로 코드를 다운로드하거나 개발 환경을 설정할 필요가 없습니다. 코드를 로컬에서 실행하려면 로컬에서 예제 실행 의 선택적 단계를 완료하세요. 개발 환경을 설정하지 않으려면 예제 탐색 으로 건너뛸 수 있습니다.

예제 코드는 GitHub 에서 사용할 수 있습니다.

(선택 사항) 로컬에서 예제 실행

전제조건

예제 앱을 로컬에서 실행하려면 개발 환경에 다음이 설치되어 있어야 합니다.

예제 앱 설치 및 실행

  1. tfjs-examples 저장소를 복제하거나 다운로드하세요.
  2. mnist-transfer-cnn 디렉터리로 변경합니다.

    cd tfjs-examples/mnist-transfer-cnn
    
  3. 종속성을 설치합니다.

    yarn
    
  4. 개발 서버를 시작합니다.

    yarn run watch
    

예제 살펴보기

예제 앱을 엽니다 . (또는 예제를 로컬에서 실행하는 경우 브라우저에서 http://localhost:1234 로 이동합니다.)

MNIST CNN Transfer Learning 이라는 제목의 페이지가 표시됩니다. 지침에 따라 앱을 사용해 보세요.

다음은 시도해 볼 수 있는 몇 가지 사항입니다.

  • 다양한 훈련 모드를 실험하고 손실과 정확도를 비교하세요.
  • 다양한 비트맵 예제를 선택하고 분류 확률을 검사합니다. 각 비트맵 예제의 숫자는 이미지의 픽셀을 나타내는 회색조 정수 값입니다.
  • 비트맵 정수 값을 편집하고 변경 사항이 분류 확률에 어떤 영향을 미치는지 확인하세요.

코드 살펴보기

예제 웹 앱은 MNIST 데이터 세트의 하위 집합에 대해 사전 훈련된 모델을 로드합니다. 사전 훈련은 Python 프로그램 mnist_transfer_cnn.py 에 정의되어 있습니다. Python 프로그램은 이 튜토리얼의 범위를 벗어나지만 모델 변환 의 예를 보고 싶다면 살펴볼 가치가 있습니다.

index.js 파일에는 데모에 대한 대부분의 학습 코드가 포함되어 있습니다. index.js 브라우저에서 실행되면 설정 함수인 setupMnistTransferCNN 재훈련 및 예측 루틴을 캡슐화하는 MnistTransferCNNPredictor 인스턴스화하고 초기화합니다.

초기화 방법인 MnistTransferCNNPredictor.init 는 모델을 로드하고, 재훈련 데이터를 로드하고, 테스트 데이터를 생성합니다. 모델을 로드하는 라인은 다음과 같습니다.

this.model = await loader.loadHostedPretrainedModel(urls.model);

loader.loadHostedPretrainedModel 의 정의를 살펴보면 tf.loadLayersModel 호출 결과를 반환하는 것을 볼 수 있습니다. Layer 객체로 구성된 모델을 로드하기 위한 TensorFlow.js API입니다.

재훈련 논리는 MnistTransferCNNPredictor.retrainModel 에 정의되어 있습니다. 사용자가 훈련 모드로 피처 레이어 동결을 선택한 경우 기본 모델의 처음 7개 레이어가 동결되고 마지막 5개 레이어만 새 데이터에 대해 훈련됩니다. 사용자가 가중치 재초기화를 선택한 경우 모든 가중치가 재설정되고 앱은 모델을 처음부터 효과적으로 훈련합니다.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

그런 다음 모델이 컴파일된 다음 model.fit() 사용하여 테스트 데이터에 대해 학습 됩니다.

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

model.fit() 매개변수에 대해 자세히 알아보려면 API 설명서를 참조하세요.

새로운 데이터 세트(숫자 5-9)에 대해 교육을 받은 후 모델을 사용하여 예측할 수 있습니다. MnistTransferCNNPredictor.predict 메서드는 model.predict() 사용하여 이를 수행합니다.

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

메모리 누수를 방지하는 데 도움이 되는 tf.tidy 사용에 유의하세요.

자세히 알아보기

이 튜토리얼에서는 TensorFlow.js를 사용하여 브라우저에서 전이 학습을 수행하는 예제 앱을 살펴보았습니다. 사전 훈련된 모델과 전이 학습에 대해 자세히 알아보려면 아래 리소스를 확인하세요.

TensorFlow.js

TensorFlow 코어