Użyj wstępnie wytrenowanego modelu

W tym samouczku poznasz przykładową aplikację internetową, która demonstruje transfer uczenia się przy użyciu interfejsu API warstw TensorFlow.js. Przykład ładuje wstępnie wytrenowany model, a następnie ponownie szkoli model w przeglądarce.

Model został wstępnie przeszkolony w języku Python na cyfrach 0–4 zbioru danych klasyfikacji cyfr MNIST . Do przekwalifikowania (lub przeniesienia nauki) w przeglądarce wykorzystywane są cyfry 5-9. Przykład pokazuje, że pierwszych kilka warstw wstępnie wyszkolonego modelu można wykorzystać do wyodrębnienia funkcji z nowych danych podczas uczenia transferowego, umożliwiając w ten sposób szybsze szkolenie na nowych danych.

Przykładowa aplikacja do tego samouczka jest dostępna online , więc nie musisz pobierać żadnego kodu ani konfigurować środowiska programistycznego. Jeśli chcesz uruchomić kod lokalnie, wykonaj opcjonalne kroki w temacie Lokalne uruchamianie przykładu . Jeśli nie chcesz konfigurować środowiska programistycznego, możesz przejść do sekcji Poznaj przykład .

Przykładowy kod jest dostępny na GitHubie .

(Opcjonalnie) Uruchom przykład lokalnie

Warunki wstępne

Aby uruchomić przykładową aplikację lokalnie, w środowisku programistycznym muszą być zainstalowane następujące elementy:

Zainstaluj i uruchom przykładową aplikację

  1. Sklonuj lub pobierz repozytorium tfjs-examples .
  2. Przejdź do katalogu mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Zainstaluj zależności:

    yarn
    
  4. Uruchom serwer deweloperski:

    yarn run watch
    

Przeanalizuj przykład

Otwórz przykładową aplikację . (Lub, jeśli uruchamiasz przykład lokalnie, przejdź do http://localhost:1234 w swojej przeglądarce.)

Powinieneś zobaczyć stronę zatytułowaną MNIST CNN Transfer Learning . Postępuj zgodnie z instrukcjami, aby wypróbować aplikację.

Oto kilka rzeczy do wypróbowania:

  • Eksperymentuj z różnymi trybami treningu i porównuj straty i dokładność.
  • Wybierz różne przykłady bitmap i sprawdź prawdopodobieństwa klasyfikacji. Należy pamiętać, że liczby w każdym przykładzie mapy bitowej są wartościami całkowitymi w skali szarości reprezentującymi piksele obrazu.
  • Edytuj wartości całkowite mapy bitowej i zobacz, jak zmiany wpływają na prawdopodobieństwa klasyfikacji.

Poznaj kod

Przykładowa aplikacja internetowa ładuje model, który został wstępnie przeszkolony na podzbiorze zbioru danych MNIST. Trening wstępny zdefiniowany jest w programie w Pythonie: mnist_transfer_cnn.py . Program w języku Python wykracza poza zakres tego samouczka, ale warto go obejrzeć, jeśli chcesz zobaczyć przykład konwersji modelu .

Plik index.js zawiera większość kodu szkoleniowego wersji demonstracyjnej. Gdy index.js działa w przeglądarce, funkcja instalacyjna setupMnistTransferCNN tworzy instancję i inicjuje MnistTransferCNNPredictor , który zawiera procedury ponownego szkolenia i przewidywania.

Metoda inicjalizacji MnistTransferCNNPredictor.init ładuje model, ładuje dane do ponownego szkolenia i tworzy dane testowe. Oto linia ładująca model:

this.model = await loader.loadHostedPretrainedModel(urls.model);

Jeśli spojrzysz na definicję loader.loadHostedPretrainedModel , zobaczysz, że zwraca ona wynik wywołania metody tf.loadLayersModel . To jest interfejs API TensorFlow.js służący do ładowania modelu złożonego z obiektów warstwy.

Logika ponownego szkolenia jest zdefiniowana w MnistTransferCNNPredictor.retrainModel . Jeśli użytkownik wybrał opcję Zablokuj warstwy obiektowe jako tryb uczenia, pierwsze 7 warstw modelu podstawowego zostanie zablokowanych, a tylko ostatnie 5 warstw będzie trenowanych na nowych danych. Jeśli użytkownik wybrał opcję Reinicjuj wagi , wszystkie wagi zostaną zresetowane, a aplikacja skutecznie nauczy model od zera.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

Następnie model jest kompilowany , a następnie szkolony na danych testowych za pomocą model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

Więcej informacji na temat parametrów model.fit() można znaleźć w dokumentacji API .

Po przeszkoleniu na nowym zbiorze danych (cyfry 5–9) model można wykorzystać do przewidywania. Metoda MnistTransferCNNPredictor.predict robi to za pomocą model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Zwróć uwagę na użycie tf.tidy , które pomaga zapobiegać wyciekom pamięci.

Dowiedz się więcej

W tym samouczku omówiono przykładową aplikację, która wykonuje naukę transferu w przeglądarce przy użyciu TensorFlow.js. Zapoznaj się z poniższymi zasobami, aby dowiedzieć się więcej o wstępnie wytrenowanych modelach i transferze uczenia się.

TensorFlow.js

Rdzeń TensorFlow