이 튜토리얼에서는 정수 덧셈을 수행하도록 RNN( Recurrent Neural Network )을 교육하기 위해 웹 작업자를 사용하는 웹 애플리케이션의 예를 살펴보겠습니다. 예제 앱은 더하기 연산자를 명시적으로 정의하지 않습니다. 대신 예제 합계를 사용하여 RNN을 훈련합니다.
물론 이것은 두 개의 정수를 더하는 가장 효율적인 방법은 아닙니다! 하지만 튜토리얼에서는 웹 ML의 중요한 기술, 즉 UI 로직을 처리하는 메인 스레드를 차단하지 않고 장기 실행 계산을 수행하는 방법을 보여줍니다.
이 튜토리얼의 예제 애플리케이션은 온라인으로 제공 되므로 코드를 다운로드하거나 개발 환경을 설정할 필요가 없습니다. 코드를 로컬에서 실행하려면 로컬에서 예제 실행 의 선택적 단계를 완료하세요. 개발 환경을 설정하지 않으려면 예제 탐색 으로 건너뛸 수 있습니다.
예제 코드는 GitHub 에서 사용할 수 있습니다.
(선택 사항) 로컬에서 예제 실행
전제조건
예제 앱을 로컬에서 실행하려면 개발 환경에 다음이 설치되어 있어야 합니다.
예제 앱 설치 및 실행
-
tfjs-examples
저장소를 복제하거나 다운로드하세요. addition-rnn-webworker
디렉터리로 변경합니다.cd tfjs-examples/addition-rnn-webworker
종속성을 설치합니다.
yarn
개발 서버를 시작합니다.
yarn run watch
예제 살펴보기
예제 앱을 엽니다 . (또는 예제를 로컬에서 실행하는 경우 브라우저에서 http://localhost:1234
로 이동합니다.)
TensorFlow.js: Addition RNN 이라는 제목의 페이지가 표시됩니다. 지침에 따라 앱을 사용해 보세요.
웹 양식을 사용하면 다음을 포함하여 모델 학습에 사용되는 일부 매개변수를 업데이트할 수 있습니다.
- Digits : 추가할 용어의 최대 자릿수입니다.
- Training Size : 생성할 훈련 예제의 수입니다.
- RNN 유형 : SimpleRNN , GRU 또는 LSTM 중 하나입니다.
- RNN Hidden Layer Size : 출력 공간의 차원입니다(양의 정수여야 함).
- 배치 크기 : 그라데이션 업데이트당 샘플 수입니다.
- Train Iterations :
model.fit()
호출하여 모델을 교육하는 횟수입니다. - # of test example : 생성할 예제 문자열의 수(예:
27+41
)입니다.
다양한 매개변수를 사용하여 모델을 훈련하고 다양한 숫자 집합에 대한 예측 정확도를 향상시킬 수 있는지 확인하세요. 또한 모델 적합 시간이 다양한 매개변수에 의해 어떻게 영향을 받는지 확인하세요.
코드 살펴보기
예제 앱은 RNN 교육을 위해 구성할 수 있는 일부 매개변수를 보여줍니다. 또한 웹 작업자를 사용하여 메인 스레드에서 모델을 훈련시키는 방법도 보여줍니다. 웹 작업자는 백그라운드 스레드에서 계산 비용이 많이 드는 훈련 작업을 실행할 수 있게 하여 잠재적으로 사용자에게 영향을 미치는 메인 스레드의 성능 문제를 방지할 수 있기 때문에 웹 ML에서 중요합니다. 기본 스레드와 작업자 스레드는 메시지 이벤트를 통해 서로 통신합니다.
웹 작업자에 대해 자세히 알아보려면 웹 작업자 API 및 웹 작업자 사용을 참조하세요.
예제 앱의 기본 모듈은 index.js
입니다. index.js
스크립트는 worker.js
모듈을 실행하는 웹 작업자를 생성합니다 .
const worker =
new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});
index.js
는 양식 제출을 처리하고, 양식 데이터를 처리하고, 양식 데이터를 작업자에게 전달하고, 작업자가 모델을 훈련하고 결과를 반환할 때까지 기다린 후 페이지에 결과를 표시하는 단일 함수 runAdditionRNNDemo
로 크게 구성됩니다. .
양식 데이터를 작업자에게 보내기 위해 스크립트는 작업자에서 postMessage
호출합니다 .
worker.postMessage({
digits,
trainingSize,
rnnType,
layers,
hiddenSize,
trainIterations,
batchSize,
numTestExamples
});
작업자는 이 메시지를 수신 하고 데이터를 준비하고 훈련을 시작하는 함수에 양식 데이터를 전달합니다.
self.addEventListener('message', async (e) => {
const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
await demo.train(trainIterations, batchSize, numTestExamples);
})
훈련 중에 작업자는 isPredict
true true
설정된 두 가지 메시지 유형을 보낼 수 있습니다.
self.postMessage({
isPredict: true,
i, iterations, modelFitTime,
lossValues, accuracyValues,
});
다른 하나는 isPredict
false
로 설정되어 있습니다.
self.postMessage({
isPredict: false,
isCorrect, examples
});
UI 스레드( index.js
)는 메시지 이벤트를 처리할 때 isPredict
플래그를 확인하여 작업자에서 반환된 데이터의 형태를 결정합니다. isPredict
가 true인 경우 데이터는 예측을 나타내야 하며 스크립트는 tfjs-vis
사용하여 페이지를 업데이트합니다 . isPredict
가 false인 경우 스크립트는 데이터가 예제를 나타낸다고 가정하는 코드 블록을 실행합니다. 데이터를 HTML로 래핑하고 HTML을 페이지에 삽입합니다.
다음은 무엇입니까?
이 튜토리얼에서는 장기 실행 학습 프로세스로 인해 UI 스레드가 차단되는 것을 방지하기 위해 웹 작업자를 사용하는 예를 제공했습니다. 백그라운드 스레드에서 비용이 많이 드는 계산을 수행할 때의 이점에 대해 자세히 알아보려면 웹 작업자를 사용하여 브라우저의 기본 스레드에서 JavaScript 실행을 참조하세요.
TensorFlow.js 모델 학습에 대해 자세히 알아보려면 모델 학습을 참조하세요.