TensorFlow.js zapewnia funkcjonalność zapisywania i ładowania modeli utworzonych za pomocą interfejsu Layers
API lub skonwertowanych z istniejących modeli TensorFlow. Mogą to być modele, które sam wytrenowałeś lub które przeszkolili inni. Kluczową zaletą korzystania z interfejsu API Layers jest to, że utworzone za jego pomocą modele można serializować i właśnie to omówimy w tym samouczku.
Ten samouczek skupi się na zapisywaniu i ładowaniu modeli TensorFlow.js (identyfikowanych przez pliki JSON). Możemy również importować modele TensorFlow Python. Ładowanie tych modeli opisano w dwóch poniższych samouczkach:
Zapisz model tf
Zarówno tf.Model
, jak i tf.Sequential
udostępniają funkcję model.save
, która pozwala zapisać topologię i wagi modelu.
Topologia: Jest to plik opisujący architekturę modelu (tj. jakie operacje wykorzystuje). Zawiera odniesienia do ciężarów modeli, które są przechowywane zewnętrznie.
Wagi: Są to pliki binarne przechowujące wagi danego modelu w wydajnym formacie. Zazwyczaj są one przechowywane w tym samym folderze co topologia.
Przyjrzyjmy się jak wygląda kod zapisu modelu
const saveResult = await model.save('localstorage://my-model-1');
Kilka rzeczy, na które warto zwrócić uwagę:
- Metoda
save
przyjmuje argument w postaci ciągu znaków przypominający adres URL, który zaczyna się od schematu . Opisuje typ miejsca docelowego, w którym próbujemy zapisać model. W powyższym przykładzie schemat tolocalstorage://
- Po schemacie następuje ścieżka . W powyższym przykładzie ścieżka to
my-model-1
. - Metoda
save
jest asynchroniczna. - Wartość zwracana przez
model.save
to obiekt JSON, który zawiera informacje, takie jak wielkość bajtów topologii i wagi modelu. - Środowisko użyte do zapisania modelu nie ma wpływu na to, które środowiska mogą załadować model. Zapisanie modelu w node.js nie uniemożliwia załadowania go w przeglądarce.
Poniżej przeanalizujemy różne dostępne schematy.
Pamięć lokalna (tylko przeglądarka)
Schemat: localstorage://
await model.save('localstorage://my-model');
Spowoduje to zapisanie modelu pod nazwą my-model
w pamięci lokalnej przeglądarki. Będzie się to utrzymywać między odświeżeniami, chociaż pamięć lokalna może zostać wyczyszczona przez użytkowników lub samą przeglądarkę, jeśli problemem stanie się miejsce. Każda przeglądarka ustala także własny limit ilości danych, które mogą być przechowywane w pamięci lokalnej dla danej domeny.
IndexedDB (tylko przeglądarka)
Schemat: indexeddb://
await model.save('indexeddb://my-model');
Spowoduje to zapisanie modelu w pamięci IndexedDB przeglądarki. Podobnie jak pamięć lokalna, utrzymuje się ona pomiędzy odświeżeniami, ma również większe ograniczenia dotyczące rozmiaru przechowywanych obiektów.
Pobieranie plików (tylko przeglądarka)
Schemat: downloads://
await model.save('downloads://my-model');
Spowoduje to, że przeglądarka pobierze pliki modelu na komputer użytkownika. Zostaną utworzone dwa pliki:
- Tekstowy plik JSON o nazwie
[my-model].json
, który zawiera topologię i odniesienie do pliku wag opisanego poniżej. - Plik binarny zawierający wartości wag o nazwie
[my-model].weights.bin
.
Możesz zmienić nazwę [my-model]
aby uzyskać pliki o innej nazwie.
Ponieważ plik .json
wskazuje na plik .bin
przy użyciu ścieżki względnej, oba pliki powinny znajdować się w tym samym folderze.
Żądanie HTTP(S).
Schemat: http://
lub https://
await model.save('http://model-server.domain/upload')
Spowoduje to utworzenie żądania internetowego w celu zapisania modelu na zdalnym serwerze. Powinieneś mieć kontrolę nad tym zdalnym serwerem, aby mieć pewność, że jest on w stanie obsłużyć żądanie.
Model zostanie wysłany do określonego serwera HTTP za pośrednictwem żądania POST . Treść testu POST jest w formacie multipart/form-data
i składa się z dwóch plików
- Tekstowy plik JSON o nazwie
model.json
, który zawiera topologię i odniesienie do pliku wag opisanego poniżej. - Plik binarny zawierający wartości wag o nazwie
model.weights.bin
.
Należy pamiętać, że nazwa obu plików będzie zawsze dokładnie taka, jak określono powyżej (nazwa jest wbudowana w funkcję). Ten dokument interfejsu API zawiera fragment kodu Pythona, który demonstruje, w jaki sposób można wykorzystać framework sieciowy Flask do obsługi żądania pochodzącego z save
.
Często będziesz musiał przekazać więcej argumentów lub nagłówków żądań do swojego serwera HTTP (np. w celu uwierzytelnienia lub jeśli chcesz określić folder, w którym model powinien zostać zapisany). Możesz uzyskać szczegółową kontrolę nad tymi aspektami żądań z save
zastępując argument ciągu adresu URL w tf.io.browserHTTPRequest
. Ten interfejs API zapewnia większą elastyczność w kontrolowaniu żądań HTTP.
Na przykład:
await model.save(tf.io.browserHTTPRequest(
'http://model-server.domain/upload',
{method: 'PUT', headers: {'header_key_1': 'header_value_1'} }));
Natywny system plików (tylko Node.js)
Schemat: file://
await model.save('file:///path/to/my-model');
Uruchamiając na Node.js mamy także bezpośredni dostęp do systemu plików i możemy tam zapisywać modele. Powyższe polecenie zapisze dwa pliki w path
określonej według scheme
.
- Tekstowy plik JSON o nazwie
[model].json
, który zawiera topologię i odniesienie do pliku wag opisanego poniżej. - Plik binarny zawierający wartości wag o nazwie
[model].weights.bin
.
Należy pamiętać, że nazwa obu plików będzie zawsze dokładnie taka, jak określono powyżej (nazwa jest wbudowana w funkcję).
Ładowanie modelu tf
Mając model, który został zapisany jedną z powyższych metod, możemy go załadować za pomocą API tf.loadLayersModel
.
Przyjrzyjmy się jak wygląda kod ładujący model
const model = await tf.loadLayersModel('localstorage://my-model-1');
Kilka rzeczy, na które warto zwrócić uwagę:
- Podobnie jak
model.save()
, funkcjaloadLayersModel
przyjmuje argument w postaci ciągu znaków przypominający adres URL, rozpoczynający się od schematu . Opisuje typ miejsca docelowego, z którego próbujemy załadować model. - Po schemacie następuje ścieżka . W powyższym przykładzie ścieżka to
my-model-1
. - Ciąg przypominający adres URL można zastąpić obiektem pasującym do interfejsu IOHandler.
- Funkcja
tf.loadLayersModel()
jest asynchroniczna. - Wartość zwracana przez
tf.loadLayersModel
totf.Model
Poniżej przeanalizujemy różne dostępne schematy.
Pamięć lokalna (tylko przeglądarka)
Schemat: localstorage://
const model = await tf.loadLayersModel('localstorage://my-model');
Spowoduje to załadowanie modelu o nazwie my-model
z pamięci lokalnej przeglądarki.
IndexedDB (tylko przeglądarka)
Schemat: indexeddb://
const model = await tf.loadLayersModel('indexeddb://my-model');
Spowoduje to załadowanie modelu z pamięci IndexedDB przeglądarki.
HTTP(S)
Schemat: http://
lub https://
const model = await tf.loadLayersModel('http://model-server.domain/download/model.json');
Spowoduje to załadowanie modelu z punktu końcowego http. Po załadowaniu pliku json
funkcja będzie wysyłać żądania dotyczące odpowiednich plików .bin
, do których odwołuje się plik json
.
Natywny system plików (tylko Node.js)
Schemat: file://
const model = await tf.loadLayersModel('file://path/to/my-model/model.json');
Działając na Node.js mamy również bezpośredni dostęp do systemu plików i możemy stamtąd ładować modele. Należy pamiętać, że w powyższym wywołaniu funkcji odwołujemy się do samego pliku model.json (podczas zapisywania określamy folder). Odpowiednie pliki .bin
powinny znajdować się w tym samym folderze, co plik json
.
Ładowanie modeli za pomocą IOHhandlerów
Jeśli powyższe schematy nie są wystarczające dla Twoich potrzeb, możesz zaimplementować niestandardowe zachowanie ładowania za pomocą IOHandler
. Jednym z IOHandler
udostępnianym przez TensorFlow.js jest tf.io.browserFiles
, który umożliwia użytkownikom przeglądarki przesyłanie plików modeli do przeglądarki. Więcej informacji można znaleźć w dokumentacji .
Zapisywanie i ładowanie modeli za pomocą niestandardowych IOHhandlerów
Jeśli powyższe schematy nie są wystarczające dla Twoich potrzeb związanych z ładowaniem lub zapisywaniem, możesz zaimplementować niestandardowe zachowanie serializacji, implementując IOHandler
.
IOHandler
to obiekt z metodą save
i load
.
Funkcja save
przyjmuje jeden parametr, który jest zgodny z interfejsem ModelArtifacts i powinna zwracać obietnicę, która prowadzi do obiektu SaveResult .
Funkcja load
nie przyjmuje żadnych parametrów i powinna zwracać obietnicę, która prowadzi do obiektu ModelArtifacts . Jest to ten sam obiekt, który jest przekazywany do save
.
Zobacz BrowserHTTPRequest, aby zapoznać się z przykładem implementacji IOHandler.