W tym samouczku poznasz przykładową aplikację internetową, która demonstruje transfer uczenia się przy użyciu interfejsu API warstw TensorFlow.js. Przykład ładuje wstępnie wytrenowany model, a następnie ponownie szkoli model w przeglądarce.
Model został wstępnie przeszkolony w języku Python na cyfrach 0–4 zbioru danych klasyfikacji cyfr MNIST . Do przekwalifikowania (lub przeniesienia nauki) w przeglądarce wykorzystywane są cyfry 5-9. Przykład pokazuje, że pierwszych kilka warstw wstępnie wyszkolonego modelu można wykorzystać do wyodrębnienia funkcji z nowych danych podczas uczenia transferowego, umożliwiając w ten sposób szybsze szkolenie na nowych danych.
Przykładowa aplikacja do tego samouczka jest dostępna online , więc nie musisz pobierać żadnego kodu ani konfigurować środowiska programistycznego. Jeśli chcesz uruchomić kod lokalnie, wykonaj opcjonalne kroki w temacie Lokalne uruchamianie przykładu . Jeśli nie chcesz konfigurować środowiska programistycznego, możesz przejść do sekcji Poznaj przykład .
Przykładowy kod jest dostępny na GitHubie .
(Opcjonalnie) Uruchom przykład lokalnie
Warunki wstępne
Aby uruchomić przykładową aplikację lokalnie, w środowisku programistycznym muszą być zainstalowane następujące elementy:
- Node.js ( pobierz )
- Przędza ( zainstaluj )
Zainstaluj i uruchom przykładową aplikację
- Sklonuj lub pobierz repozytorium
tfjs-examples
. Przejdź do katalogu
mnist-transfer-cnn
:cd tfjs-examples/mnist-transfer-cnn
Zainstaluj zależności:
yarn
Uruchom serwer deweloperski:
yarn run watch
Przeanalizuj przykład
Otwórz przykładową aplikację . (Lub, jeśli uruchamiasz przykład lokalnie, przejdź do http://localhost:1234
w swojej przeglądarce.)
Powinieneś zobaczyć stronę zatytułowaną MNIST CNN Transfer Learning . Postępuj zgodnie z instrukcjami, aby wypróbować aplikację.
Oto kilka rzeczy do wypróbowania:
- Eksperymentuj z różnymi trybami treningu i porównuj straty i dokładność.
- Wybierz różne przykłady bitmap i sprawdź prawdopodobieństwa klasyfikacji. Należy pamiętać, że liczby w każdym przykładzie mapy bitowej są wartościami całkowitymi w skali szarości reprezentującymi piksele obrazu.
- Edytuj wartości całkowite mapy bitowej i zobacz, jak zmiany wpływają na prawdopodobieństwa klasyfikacji.
Poznaj kod
Przykładowa aplikacja internetowa ładuje model, który został wstępnie przeszkolony na podzbiorze zbioru danych MNIST. Trening wstępny zdefiniowany jest w programie w Pythonie: mnist_transfer_cnn.py
. Program w języku Python wykracza poza zakres tego samouczka, ale warto go obejrzeć, jeśli chcesz zobaczyć przykład konwersji modelu .
Plik index.js
zawiera większość kodu szkoleniowego wersji demonstracyjnej. Gdy index.js
działa w przeglądarce, funkcja instalacyjna setupMnistTransferCNN
tworzy instancję i inicjuje MnistTransferCNNPredictor
, który zawiera procedury ponownego szkolenia i przewidywania.
Metoda inicjalizacji MnistTransferCNNPredictor.init
ładuje model, ładuje dane do ponownego szkolenia i tworzy dane testowe. Oto linia ładująca model:
this.model = await loader.loadHostedPretrainedModel(urls.model);
Jeśli spojrzysz na definicję loader.loadHostedPretrainedModel
, zobaczysz, że zwraca ona wynik wywołania metody tf.loadLayersModel
. To jest interfejs API TensorFlow.js służący do ładowania modelu złożonego z obiektów warstwy.
Logika ponownego szkolenia jest zdefiniowana w MnistTransferCNNPredictor.retrainModel
. Jeśli użytkownik wybrał opcję Zablokuj warstwy obiektowe jako tryb uczenia, pierwsze 7 warstw modelu podstawowego zostanie zablokowanych, a tylko ostatnie 5 warstw będzie trenowanych na nowych danych. Jeśli użytkownik wybrał opcję Reinicjuj wagi , wszystkie wagi zostaną zresetowane, a aplikacja skutecznie nauczy model od zera.
if (trainingMode === 'freeze-feature-layers') {
console.log('Freezing feature layers of the model.');
for (let i = 0; i < 7; ++i) {
this.model.layers[i].trainable = false;
}
} else if (trainingMode === 'reinitialize-weights') {
// Make a model with the same topology as before, but with re-initialized
// weight values.
const returnString = false;
this.model = await tf.models.modelFromJSON({
modelTopology: this.model.toJSON(null, returnString)
});
}
Następnie model jest kompilowany , a następnie szkolony na danych testowych za pomocą model.fit()
:
await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
batchSize: batchSize,
epochs: epochs,
validationData: [this.gte5TestData.x, this.gte5TestData.y],
callbacks: [
ui.getProgressBarCallbackConfig(epochs),
tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
zoomToFit: true,
zoomToFitAccuracy: true,
height: 200,
callbacks: ['onEpochEnd'],
}),
]
});
Więcej informacji na temat parametrów model.fit()
można znaleźć w dokumentacji API .
Po przeszkoleniu na nowym zbiorze danych (cyfry 5–9) model można wykorzystać do przewidywania. Metoda MnistTransferCNNPredictor.predict
robi to za pomocą model.predict()
:
// Perform prediction on the input image using the loaded model.
predict(imageText) {
tf.tidy(() => {
try {
const image = util.textToImageArray(imageText, this.imageSize);
const predictOut = this.model.predict(image);
const winner = predictOut.argMax(1);
ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
} catch (e) {
ui.setPredictError(e.message);
}
});
}
Zwróć uwagę na użycie tf.tidy
, które pomaga zapobiegać wyciekom pamięci.
Dowiedz się więcej
W tym samouczku omówiono przykładową aplikację, która wykonuje naukę transferu w przeglądarce przy użyciu TensorFlow.js. Zapoznaj się z poniższymi zasobami, aby dowiedzieć się więcej o wstępnie wytrenowanych modelach i transferze uczenia się.
TensorFlow.js
- Importowanie modelu Keras do TensorFlow.js
- Zaimportuj model TensorFlow do TensorFlow.js
- Gotowe modele dla TensorFlow.js
Rdzeń TensorFlow