TensorFlow.js предоставляет функциональные возможности для сохранения и загрузки моделей, созданных с помощью Layers
API или преобразованных из существующих моделей TensorFlow. Это могут быть модели, которых вы обучили сами, или модели, обученные другими. Ключевым преимуществом использования Layers API является то, что модели, созданные с его помощью, можно сериализовать, и это то, что мы рассмотрим в этом руководстве.
В этом руководстве основное внимание будет уделено сохранению и загрузке моделей TensorFlow.js (идентифицируемых по файлам JSON). Мы также можем импортировать модели TensorFlow Python. Загрузка этих моделей описана в следующих двух руководствах:
Сохраните tf.Model
tf.Model
и tf.Sequential
предоставляют функцию model.save
, которая позволяет сохранять топологию и веса модели.
Топология: это файл, описывающий архитектуру модели (т.е. какие операции она использует). Он содержит ссылки на веса моделей, которые хранятся снаружи.
Веса: это двоичные файлы, в которых хранятся веса данной модели в эффективном формате. Обычно они хранятся в той же папке, что и топология.
Давайте посмотрим, как выглядит код сохранения модели
const saveResult = await model.save('localstorage://my-model-1');
Несколько вещей, на которые следует обратить внимание:
- Метод
save
принимает строковый аргумент, подобный URL-адресу, который начинается со схемы . Это описывает тип места назначения, в котором мы пытаемся сохранить модель. В приведенном выше примере схема —localstorage://
- По схеме следует путь . В приведенном выше примере путь —
my-model-1
. - Метод
save
является асинхронным. - Возвращаемое значение
model.save
— это объект JSON, который содержит такую информацию, как размеры топологии модели в байтах и ее веса. - Среда, используемая для сохранения модели, не влияет на то, какие среды могут загружать модель. Сохранение модели в node.js не препятствует ее загрузке в браузере.
Ниже мы рассмотрим различные доступные схемы.
Локальное хранилище (только браузер)
Схема: localstorage://
await model.save('localstorage://my-model');
При этом модель сохраняется под именем my-model
в локальном хранилище браузера. Это будет сохраняться между обновлениями, хотя локальное хранилище может быть очищено пользователями или самим браузером, если пространство становится проблемой. Каждый браузер также устанавливает свой собственный лимит на объем данных, которые могут храниться в локальном хранилище для данного домена.
IndexedDB (только браузер)
Схема: indexeddb://
await model.save('indexeddb://my-model');
При этом модель сохраняется в хранилище IndexedDB браузера. Как и локальное хранилище, оно сохраняется между обновлениями, но также имеет тенденцию иметь более высокие ограничения на размер хранимых объектов.
Загрузка файлов (только браузер)
Схема: downloads://
await model.save('downloads://my-model');
Это заставит браузер загрузить файлы модели на компьютер пользователя. Будут созданы два файла:
- Текстовый файл JSON с именем
[my-model].json
, который содержит топологию и ссылку на файл весов, описанный ниже. - Двоичный файл, содержащий значения веса, с именем
[my-model].weights.bin
.
Вы можете изменить имя [my-model]
, чтобы получить файлы с другим именем.
Поскольку файл .json
указывает на .bin
по относительному пути, эти два файла должны находиться в одной папке.
HTTP(S) запрос
Схема: http://
или https://
await model.save('http://model-server.domain/upload')
Это создаст веб-запрос для сохранения модели на удаленном сервере. Вы должны контролировать этот удаленный сервер, чтобы быть уверенным, что он сможет обработать запрос.
Модель будет отправлена на указанный HTTP-сервер посредством POST- запроса. Тело POST имеет формат multipart/form-data
и состоит из двух файлов.
- Текстовый файл JSON с именем
model.json
, который содержит топологию и ссылку на файл весов, описанный ниже. - Бинарный файл, содержащий значения веса, с именем
model.weights.bin
.
Обратите внимание, что имена двух файлов всегда будут такими, как указано выше (имя встроено в функцию). Этот API-документ содержит фрагмент кода Python, который демонстрирует, как можно использовать веб-инфраструктуру flask для обработки запроса, исходящего от save
.
Часто вам придется передавать на HTTP-сервер дополнительные аргументы или заголовки запросов (например, для аутентификации или если вы хотите указать папку, в которой должна быть сохранена модель). Вы можете получить детальный контроль над этими аспектами запросов от save
, заменив аргумент строки URL-адреса в tf.io.browserHTTPRequest
. Этот API обеспечивает большую гибкость в управлении HTTP-запросами.
Например:
await model.save(tf.io.browserHTTPRequest(
'http://model-server.domain/upload',
{method: 'PUT', headers: {'header_key_1': 'header_value_1'} }));
Собственная файловая система (только Node.js)
Схема: file://
await model.save('file:///path/to/my-model');
При работе на Node.js мы также имеем прямой доступ к файловой системе и можем сохранять модели там. Приведенная выше команда сохранит два файла по path
, указанному после scheme
.
- Текстовый файл JSON с именем
[model].json
, который содержит топологию и ссылку на файл весов, описанный ниже. - Двоичный файл, содержащий значения веса, с именем
[model].weights.bin
.
Обратите внимание, что имена двух файлов всегда будут такими, как указано выше (имя встроено в функцию).
Загрузка tf.Model
Учитывая модель, сохраненную с помощью одного из вышеперечисленных методов, мы можем загрузить ее с помощью API tf.loadLayersModel
.
Давайте посмотрим, как выглядит код загрузки модели
const model = await tf.loadLayersModel('localstorage://my-model-1');
Несколько вещей, на которые следует обратить внимание:
- Как и
model.save()
, функцияloadLayersModel
принимает строковый аргумент типа URL, который начинается со схемы . Это описывает тип места назначения, из которого мы пытаемся загрузить модель. - По схеме следует путь . В приведенном выше примере путь —
my-model-1
. - Строку, подобную URL-адресу, можно заменить объектом, соответствующим интерфейсу IOHandler.
- Функция
tf.loadLayersModel()
является асинхронной. - Возвращаемое значение
tf.loadLayersModel
—tf.Model
Ниже мы рассмотрим различные доступные схемы.
Локальное хранилище (только браузер)
Схема: localstorage://
const model = await tf.loadLayersModel('localstorage://my-model');
Это загружает модель с именем my-model
из локального хранилища браузера.
IndexedDB (только браузер)
Схема: indexeddb://
const model = await tf.loadLayersModel('indexeddb://my-model');
При этом модель загружается из хранилища IndexedDB браузера.
HTTP(S)
Схема: http://
или https://
const model = await tf.loadLayersModel('http://model-server.domain/download/model.json');
Это загружает модель из конечной точки http. После загрузки файла json
функция выполнит запросы на соответствующие файлы .bin
, на которые ссылается файл json
.
Собственная файловая система (только Node.js)
Схема: file://
const model = await tf.loadLayersModel('file://path/to/my-model/model.json');
При работе на Node.js мы также имеем прямой доступ к файловой системе и можем загружать модели оттуда. Обратите внимание, что в вызове функции выше мы ссылаемся на сам файл model.json (тогда как при сохранении мы указываем папку). Соответствующие файлы .bin
должны находиться в той же папке, что и файл json
.
Загрузка моделей с помощью IOHandlers
Если приведенных выше схем недостаточно для ваших нужд, вы можете реализовать собственное поведение загрузки с помощью IOHandler
. Один IOHandler
, который предоставляет TensorFlow.js, — это tf.io.browserFiles
, который позволяет пользователям браузера загружать файлы модели в браузер. Дополнительную информацию смотрите в документации .
Сохранение и загрузка моделей с помощью пользовательских обработчиков ввода-вывода
Если приведенных выше схем недостаточно для ваших потребностей в загрузке или сохранении, вы можете реализовать собственное поведение сериализации, реализовав IOHandler
.
IOHandler
— это объект с методом save
и load
.
Функция save
принимает один параметр, который соответствует интерфейсу ModelArtifacts и должен возвращать обещание, которое разрешается в объект SaveResult .
Функция load
не принимает параметров и должна возвращать обещание, которое разрешается в объект ModelArtifacts . Это тот же объект, который передается в save
.
См. BrowserHTTPRequest для примера реализации IOHandler.