Usa un modello pre-addestrato

In questo tutorial esplorerai un'applicazione web di esempio che dimostra l'apprendimento del trasferimento utilizzando l'API TensorFlow.js Layers. L'esempio carica un modello pre-addestrato e quindi riqualifica il modello nel browser.

Il modello è stato pre-addestrato in Python sulle cifre 0-4 del set di dati di classificazione delle cifre MNIST . La riqualificazione (o trasferimento dell'apprendimento) nel browser utilizza le cifre 5-9. L'esempio mostra che i primi livelli di un modello pre-addestrato possono essere utilizzati per estrarre funzionalità dai nuovi dati durante l'apprendimento del trasferimento, consentendo così un addestramento più rapido sui nuovi dati.

L'applicazione di esempio per questo tutorial è disponibile online , quindi non è necessario scaricare alcun codice o configurare un ambiente di sviluppo. Se desideri eseguire il codice localmente, completa i passaggi facoltativi in ​​Eseguire l'esempio localmente . Se non desideri configurare un ambiente di sviluppo, puoi passare a Esplora l'esempio .

Il codice di esempio è disponibile su GitHub .

(Facoltativo) Esegui l'esempio localmente

Prerequisiti

Per eseguire l'app di esempio localmente, è necessario che sia installato quanto segue nel tuo ambiente di sviluppo:

Installa ed esegui l'app di esempio

  1. Clona o scarica il repository tfjs-examples .
  2. Passare alla directory mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Installa le dipendenze:

    yarn
    
  4. Avviare il server di sviluppo:

    yarn run watch
    

Esplora l'esempio

Apri l'app di esempio . (Oppure, se stai eseguendo l'esempio localmente, vai a http://localhost:1234 nel tuo browser.)

Dovresti vedere una pagina intitolata MNIST CNN Transfer Learning . Segui le istruzioni per provare l'app.

Ecco alcune cose da provare:

  • Sperimenta le diverse modalità di allenamento e confronta perdita e precisione.
  • Seleziona diversi esempi di bitmap e controlla le probabilità di classificazione. Tieni presente che i numeri in ciascun esempio di bitmap sono valori interi in scala di grigi che rappresentano i pixel di un'immagine.
  • Modifica i valori interi della bitmap e osserva come le modifiche influiscono sulle probabilità di classificazione.

Esplora il codice

L'app Web di esempio carica un modello che è stato pre-addestrato su un sottoinsieme del set di dati MNIST. Il pre-addestramento è definito in un programma Python: mnist_transfer_cnn.py . Il programma Python non rientra nell'ambito di questo tutorial, ma vale la pena esaminarlo se desideri vedere un esempio di conversione del modello .

Il file index.js contiene la maggior parte del codice di training per la demo. Quando index.js viene eseguito nel browser, una funzione di configurazione, setupMnistTransferCNN , istanzia e inizializza MnistTransferCNNPredictor , che incapsula le routine di riqualificazione e previsione.

Il metodo di inizializzazione, MnistTransferCNNPredictor.init , carica un modello, carica i dati di riaddestramento e crea dati di test. Ecco la riga che carica il modello:

this.model = await loader.loadHostedPretrainedModel(urls.model);

Se guardi la definizione di loader.loadHostedPretrainedModel , vedrai che restituisce il risultato di una chiamata a tf.loadLayersModel . Questa è l'API TensorFlow.js per caricare un modello composto da oggetti Layer.

La logica di riqualificazione è definita in MnistTransferCNNPredictor.retrainModel . Se l'utente ha selezionato Congela feature layer come modalità di training, i primi 7 layer del modello base vengono congelati e solo gli ultimi 5 layer vengono addestrati sui nuovi dati. Se l'utente ha selezionato Reinizializza pesi , tutti i pesi vengono reimpostati e l'app addestra effettivamente il modello da zero.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

Il modello viene quindi compilato e quindi addestrato sui dati di test utilizzando model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

Per ulteriori informazioni sui parametri model.fit() , consulta la documentazione dell'API .

Dopo essere stato addestrato sul nuovo set di dati (cifre 5-9), il modello può essere utilizzato per fare previsioni. Il metodo MnistTransferCNNPredictor.predict esegue questa operazione utilizzando model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Da notare l'uso di tf.tidy , che aiuta a prevenire perdite di memoria.

Saperne di più

Questo tutorial ha esplorato un'app di esempio che esegue l'apprendimento del trasferimento nel browser utilizzando TensorFlow.js. Consulta le risorse riportate di seguito per ulteriori informazioni sui modelli pre-addestrati e sul trasferimento dell'apprendimento.

TensorFlow.js

Nucleo TensorFlow