Usa un modelo pre-entrenado

En este tutorial, explorará una aplicación web de ejemplo que demuestra el aprendizaje de transferencia mediante la API de capas de TensorFlow.js. El ejemplo carga un modelo previamente entrenado y luego vuelve a entrenar el modelo en el navegador.

El modelo ha sido entrenado previamente en Python en los dígitos 0-4 del conjunto de datos de clasificación de dígitos MNIST . El reentrenamiento (o transferencia de aprendizaje) en el navegador usa los dígitos 5-9. El ejemplo muestra que las primeras capas de un modelo preentrenado se pueden usar para extraer características de nuevos datos durante el aprendizaje de transferencia, lo que permite un entrenamiento más rápido en los nuevos datos.

La aplicación de ejemplo para este tutorial está disponible en línea , por lo que no necesita descargar ningún código ni configurar un entorno de desarrollo. Si desea ejecutar el código localmente, complete los pasos opcionales en Ejecutar el ejemplo localmente . Si no desea configurar un entorno de desarrollo, puede pasar a Explorar el ejemplo .

El código de ejemplo está disponible en GitHub .

(Opcional) Ejecute el ejemplo localmente

requisitos previos

Para ejecutar la aplicación de ejemplo localmente, necesita lo siguiente instalado en su entorno de desarrollo:

Instalar y ejecutar la aplicación de ejemplo

  1. Clone o descargue el repositorio tfjs-examples .
  2. Cambie al directorio mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Instalar dependencias:

    yarn
    
  4. Inicie el servidor de desarrollo:

    yarn run watch
    

Explora el ejemplo

Abra la aplicación de ejemplo . (O, si está ejecutando el ejemplo localmente, vaya a http://localhost:1234 en su navegador).

Debería ver una página titulada MNIST CNN Transfer Learning . Siga las instrucciones para probar la aplicación.

Aquí hay algunas cosas para probar:

  • Experimente con los diferentes modos de entrenamiento y compare la pérdida y la precisión.
  • Seleccione diferentes ejemplos de mapas de bits e inspeccione las probabilidades de clasificación. Tenga en cuenta que los números en cada ejemplo de mapa de bits son valores enteros en escala de grises que representan píxeles de una imagen.
  • Edite los valores enteros del mapa de bits y vea cómo los cambios afectan las probabilidades de clasificación.

Explora el código

La aplicación web de ejemplo carga un modelo que se ha entrenado previamente en un subconjunto del conjunto de datos MNIST. El entrenamiento previo se define en un programa de Python: mnist_transfer_cnn.py . El programa Python está fuera del alcance de este tutorial, pero vale la pena echarle un vistazo si desea ver un ejemplo de conversión de modelo .

El archivo index.js contiene la mayor parte del código de entrenamiento para la demostración. Cuando index.js se ejecuta en el navegador, una función de configuración, setupMnistTransferCNN , instancia e inicializa MnistTransferCNNPredictor , que encapsula las rutinas de reentrenamiento y predicción.

El método de inicialización, MnistTransferCNNPredictor.init , carga un modelo, carga datos de reentrenamiento y crea datos de prueba. Aquí está la línea que carga el modelo:

this.model = await loader.loadHostedPretrainedModel(urls.model);

Si observa la definición de loader.loadHostedPretrainedModel , verá que devuelve el resultado de una llamada a tf.loadLayersModel . Esta es la API de TensorFlow.js para cargar un modelo compuesto por objetos Layer.

La lógica de reentrenamiento se define en MnistTransferCNNPredictor.retrainModel . Si el usuario ha seleccionado Congelar capas de entidades como modo de entrenamiento, las primeras 7 capas del modelo base se congelan y solo las últimas 5 capas se entrenan con nuevos datos. Si el usuario seleccionó Reinicializar pesos , todos los pesos se restablecen y la aplicación entrena efectivamente al modelo desde cero.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

Luego, el modelo se compila y luego se entrena en los datos de prueba usando model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

Para obtener más información sobre los parámetros model.fit() , consulte la documentación de la API .

Después de haber sido entrenado en el nuevo conjunto de datos (dígitos 5-9), el modelo se puede usar para hacer predicciones. El método MnistTransferCNNPredictor.predict hace esto usando model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Tenga en cuenta el uso de tf.tidy , que ayuda a evitar pérdidas de memoria.

Aprende más

Este tutorial ha explorado una aplicación de ejemplo que realiza transferencias de aprendizaje en el navegador usando TensorFlow.js. Consulte los recursos a continuación para obtener más información sobre los modelos preentrenados y el aprendizaje de transferencia.

TensorFlow.js

Núcleo TensorFlow