이 튜토리얼에서는 TensorFlow.js Layers API를 사용하여 전이 학습을 보여주는 예제 웹 애플리케이션을 탐색합니다. 이 예에서는 사전 훈련된 모델을 로드한 다음 브라우저에서 모델을 다시 훈련합니다.
이 모델은 MNIST 숫자 분류 데이터세트 의 숫자 0-4에 대해 Python으로 사전 학습되었습니다. 브라우저의 재훈련(또는 전이 학습)은 숫자 5-9를 사용합니다. 이 예에서는 사전 훈련된 모델의 처음 여러 계층을 사용하여 전이 학습 중에 새 데이터에서 특징을 추출할 수 있으므로 새 데이터에 대한 더 빠른 훈련이 가능함을 보여줍니다.
이 튜토리얼의 예제 애플리케이션은 온라인으로 제공 되므로 코드를 다운로드하거나 개발 환경을 설정할 필요가 없습니다. 코드를 로컬에서 실행하려면 로컬에서 예제 실행 의 선택적 단계를 완료하세요. 개발 환경을 설정하지 않으려면 예제 탐색 으로 건너뛸 수 있습니다.
예제 코드는 GitHub 에서 사용할 수 있습니다.
(선택 사항) 로컬에서 예제 실행
전제조건
예제 앱을 로컬에서 실행하려면 개발 환경에 다음이 설치되어 있어야 합니다.
예제 앱 설치 및 실행
-
tfjs-examples
저장소를 복제하거나 다운로드하세요. mnist-transfer-cnn
디렉터리로 변경합니다.cd tfjs-examples/mnist-transfer-cnn
종속성을 설치합니다.
yarn
개발 서버를 시작합니다.
yarn run watch
예제 살펴보기
예제 앱을 엽니다 . (또는 예제를 로컬에서 실행하는 경우 브라우저에서 http://localhost:1234
로 이동합니다.)
MNIST CNN Transfer Learning 이라는 제목의 페이지가 표시됩니다. 지침에 따라 앱을 사용해 보세요.
다음은 시도해 볼 수 있는 몇 가지 사항입니다.
- 다양한 훈련 모드를 실험하고 손실과 정확도를 비교하세요.
- 다양한 비트맵 예제를 선택하고 분류 확률을 검사합니다. 각 비트맵 예제의 숫자는 이미지의 픽셀을 나타내는 회색조 정수 값입니다.
- 비트맵 정수 값을 편집하고 변경 사항이 분류 확률에 어떤 영향을 미치는지 확인하세요.
코드 살펴보기
예제 웹 앱은 MNIST 데이터 세트의 하위 집합에 대해 사전 훈련된 모델을 로드합니다. 사전 훈련은 Python 프로그램 mnist_transfer_cnn.py
에 정의되어 있습니다. Python 프로그램은 이 튜토리얼의 범위를 벗어나지만 모델 변환 의 예를 보고 싶다면 살펴볼 가치가 있습니다.
index.js
파일에는 데모에 대한 대부분의 학습 코드가 포함되어 있습니다. index.js
브라우저에서 실행되면 설정 함수인 setupMnistTransferCNN
재훈련 및 예측 루틴을 캡슐화하는 MnistTransferCNNPredictor
인스턴스화하고 초기화합니다.
초기화 방법인 MnistTransferCNNPredictor.init
는 모델을 로드하고, 재훈련 데이터를 로드하고, 테스트 데이터를 생성합니다. 모델을 로드하는 라인은 다음과 같습니다.
this.model = await loader.loadHostedPretrainedModel(urls.model);
loader.loadHostedPretrainedModel
의 정의를 살펴보면 tf.loadLayersModel
호출 결과를 반환하는 것을 볼 수 있습니다. Layer 객체로 구성된 모델을 로드하기 위한 TensorFlow.js API입니다.
재훈련 논리는 MnistTransferCNNPredictor.retrainModel
에 정의되어 있습니다. 사용자가 훈련 모드로 피처 레이어 동결을 선택한 경우 기본 모델의 처음 7개 레이어가 동결되고 마지막 5개 레이어만 새 데이터에 대해 훈련됩니다. 사용자가 가중치 재초기화를 선택한 경우 모든 가중치가 재설정되고 앱은 모델을 처음부터 효과적으로 훈련합니다.
if (trainingMode === 'freeze-feature-layers') {
console.log('Freezing feature layers of the model.');
for (let i = 0; i < 7; ++i) {
this.model.layers[i].trainable = false;
}
} else if (trainingMode === 'reinitialize-weights') {
// Make a model with the same topology as before, but with re-initialized
// weight values.
const returnString = false;
this.model = await tf.models.modelFromJSON({
modelTopology: this.model.toJSON(null, returnString)
});
}
그런 다음 모델이 컴파일된 다음 model.fit()
사용하여 테스트 데이터에 대해 학습 됩니다.
await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
batchSize: batchSize,
epochs: epochs,
validationData: [this.gte5TestData.x, this.gte5TestData.y],
callbacks: [
ui.getProgressBarCallbackConfig(epochs),
tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
zoomToFit: true,
zoomToFitAccuracy: true,
height: 200,
callbacks: ['onEpochEnd'],
}),
]
});
model.fit()
매개변수에 대해 자세히 알아보려면 API 설명서를 참조하세요.
새로운 데이터 세트(숫자 5-9)에 대해 교육을 받은 후 모델을 사용하여 예측할 수 있습니다. MnistTransferCNNPredictor.predict
메서드는 model.predict()
사용하여 이를 수행합니다.
// Perform prediction on the input image using the loaded model.
predict(imageText) {
tf.tidy(() => {
try {
const image = util.textToImageArray(imageText, this.imageSize);
const predictOut = this.model.predict(image);
const winner = predictOut.argMax(1);
ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
} catch (e) {
ui.setPredictError(e.message);
}
});
}
메모리 누수를 방지하는 데 도움이 되는 tf.tidy
사용에 유의하세요.
자세히 알아보기
이 튜토리얼에서는 TensorFlow.js를 사용하여 브라우저에서 전이 학습을 수행하는 예제 앱을 살펴보았습니다. 사전 훈련된 모델과 전이 학습에 대해 자세히 알아보려면 아래 리소스를 확인하세요.
TensorFlow.js
TensorFlow 코어