Entrenar un modelo usando un trabajador web

En este tutorial, explorará una aplicación web de ejemplo que utiliza un trabajador web para entrenar una red neuronal recurrente (RNN) para realizar sumas de enteros. La aplicación de ejemplo no define explícitamente el operador de suma. En cambio, entrena al RNN usando sumas de ejemplo.

¡Por supuesto, esta no es la forma más eficiente de sumar dos números enteros! Pero el tutorial demuestra una técnica importante en web ML: cómo realizar cálculos de ejecución prolongada sin bloquear el subproceso principal, que maneja la lógica de la interfaz de usuario.

La aplicación de ejemplo para este tutorial está disponible en línea , por lo que no necesita descargar ningún código ni configurar un entorno de desarrollo. Si desea ejecutar el código localmente, complete los pasos opcionales en Ejecutar el ejemplo localmente . Si no desea configurar un entorno de desarrollo, puede pasar a Explorar el ejemplo .

El código de ejemplo está disponible en GitHub .

(Opcional) Ejecute el ejemplo localmente

requisitos previos

Para ejecutar la aplicación de ejemplo localmente, necesita lo siguiente instalado en su entorno de desarrollo:

Instalar y ejecutar la aplicación de ejemplo

  1. Clone o descargue el repositorio tfjs-examples .
  2. Cambie al directorio addition-rnn-webworker :

    cd tfjs-examples/addition-rnn-webworker
    
  3. Instalar dependencias:

    yarn
    
  4. Inicie el servidor de desarrollo:

    yarn run watch
    

Explora el ejemplo

Abra la aplicación de ejemplo . (O, si está ejecutando el ejemplo localmente, vaya a http://localhost:1234 en su navegador).

Debería ver una página titulada TensorFlow.js: Addition RNN . Siga las instrucciones para probar la aplicación.

Mediante el formulario web, puede actualizar algunos de los parámetros utilizados para entrenar el modelo, incluidos los siguientes:

  • Dígitos : El número máximo de dígitos en los términos que se agregarán.
  • Tamaño de entrenamiento : El número de ejemplos de entrenamiento para generar.
  • Tipo de RNN : Uno de SimpleRNN , GRU o LSTM .
  • RNN Hidden Layer Size : Dimensionalidad del espacio de salida (debe ser un número entero positivo).
  • Tamaño del lote : número de muestras por actualización de gradiente.
  • Iteraciones de entrenamiento : Número de veces para entrenar el modelo invocando model.fit()
  • # de ejemplos de prueba : número de cadenas de ejemplo (por ejemplo, 27+41 ) para generar.

Intente entrenar el modelo con diferentes parámetros y vea si puede mejorar la precisión de las predicciones para varios conjuntos de dígitos. Observe también cómo el tiempo de ajuste del modelo se ve afectado por diferentes parámetros.

Explora el código

La aplicación de ejemplo demuestra algunos de los parámetros que puede configurar para entrenar un RNN. También demuestra el uso de un trabajador web para entrenar un modelo a partir del hilo principal. Los trabajadores web son importantes en el aprendizaje automático web porque le permiten ejecutar tareas de entrenamiento computacionalmente costosas en un subproceso en segundo plano, lo que evita posibles problemas de rendimiento que afecten al usuario en el subproceso principal. Los subprocesos principal y de trabajo se comunican entre sí a través de eventos de mensajes.

Para obtener más información sobre los trabajadores web, consulte la API de trabajadores web y el uso de trabajadores web .

El módulo principal de la aplicación de ejemplo es index.js . El script index.js crea un web worker que ejecuta el módulo worker.js :

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.js se compone en gran medida de una sola función, runAdditionRNNDemo , que maneja el envío de formularios, procesa los datos del formulario, pasa los datos del formulario al trabajador, espera a que el trabajador entrene el modelo y devuelva los resultados, y luego muestra los resultados en la página .

Para enviar los datos del formulario al trabajador, el script invoca postMessage en el trabajador:

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});

El trabajador escucha este mensaje y pasa los datos del formulario a las funciones que preparan los datos e inician el entrenamiento:

self.addEventListener('message', async (e) => {
  const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
})

Durante el entrenamiento, el trabajador puede enviar dos tipos de mensajes diferentes, uno con isPredict establecido en true

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

y el otro con isPredict establecido en false .

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

Cuando el subproceso de la interfaz de usuario ( index.js ) maneja eventos de mensajes, verifica el indicador isPredict para determinar la forma de los datos devueltos por el trabajador. Si isPredict es verdadero, los datos deben representar una predicción y el script actualiza la página mediante tfjs-vis . Si isPredict es falso, el script ejecuta un bloque de código que asume que los datos representan ejemplos. Envuelve los datos en HTML e inserta el HTML en la página.

Que sigue

Este tutorial ha brindado un ejemplo del uso de un trabajador web para evitar bloquear el subproceso de la interfaz de usuario con un proceso de capacitación de ejecución prolongada. Para obtener más información sobre los beneficios de realizar cálculos costosos en un subproceso en segundo plano, consulte Usar trabajadores web para ejecutar JavaScript fuera del subproceso principal del navegador .

Para obtener más información sobre cómo entrenar un modelo de TensorFlow.js, consulta Modelos de entrenamiento .