Các mô hình dựa trên TensorFlow GraphDef (thường được tạo thông qua API Python) có thể được lưu ở một trong các định dạng sau:
- TensorFlow SavingMô hình
- Người mẫu đông lạnh
- Mô-đun trung tâm Tensorflow
Tất cả các định dạng trên có thể được chuyển đổi bằng trình chuyển đổi TensorFlow.js thành định dạng có thể tải trực tiếp vào TensorFlow.js để suy luận.
(Lưu ý: TensorFlow không còn dùng định dạng gói phiên nữa. Vui lòng di chuyển mô hình của bạn sang định dạng SavingModel.)
Yêu cầu
Quy trình chuyển đổi yêu cầu môi trường Python; bạn có thể muốn giữ một cái riêng biệt bằng cách sử dụng pipenv hoặc virtualenv .
Để cài đặt bộ chuyển đổi, hãy chạy lệnh sau:
pip install tensorflowjs
Nhập mô hình TensorFlow vào TensorFlow.js là một quá trình gồm hai bước. Đầu tiên, chuyển đổi mô hình hiện có sang định dạng web TensorFlow.js, sau đó tải mô hình đó vào TensorFlow.js.
Bước 1. Chuyển đổi mô hình TensorFlow hiện có sang định dạng web TensorFlow.js
Chạy tập lệnh chuyển đổi được cung cấp bởi gói pip:
Ví dụ về SavingModel:
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
Ví dụ về mô hình đông lạnh:
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
Ví dụ về mô-đun Tensorflow Hub:
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
Đối số vị trí | Sự miêu tả |
---|---|
input_path | Đường dẫn đầy đủ của thư mục mô hình đã lưu, thư mục gói phiên, tệp mô hình cố định hoặc đường dẫn hoặc điều khiển mô-đun TensorFlow Hub. |
output_path | Đường dẫn cho tất cả các tạo phẩm đầu ra. |
Tùy chọn | Sự miêu tả |
---|---|
--input_format | Định dạng của mô hình đầu vào. Sử dụng tf_saved_model cho SavingModel, tf_frozen_model cho mô hình cố định, tf_session_bundle cho gói phiên, tf_hub cho mô-đun TensorFlow Hub và máy ảnh cho Keras HDF5. |
--output_node_names | Tên của các nút đầu ra, được phân tách bằng dấu phẩy. |
--saved_model_tags | Chỉ áp dụng cho chuyển đổi SavingModel. Các thẻ của MetaGraphDef cần tải, ở định dạng được phân tách bằng dấu phẩy. Mặc định để serve . |
--signature_name | Chỉ áp dụng cho chuyển đổi mô-đun TensorFlow Hub, chữ ký để tải. Mặc định là default . Xem https://www.tensorflow.org/hub/common_signatures/ |
Sử dụng lệnh sau để nhận thông báo trợ giúp chi tiết:
tensorflowjs_converter --help
Chuyển đổi tập tin được tạo
Tập lệnh chuyển đổi ở trên tạo ra hai loại tệp:
-
model.json
: Biểu đồ luồng dữ liệu và bảng kê khai trọng số -
group1-shard\*of\*
: Tập hợp các tệp trọng số nhị phân
Ví dụ: đây là kết quả từ việc chuyển đổi MobileNet v2:
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
Bước 2: Load và chạy trên trình duyệt
- Cài đặt gói npm tfjs-converter:
yarn add @tensorflow/tfjs
hoặc npm install @tensorflow/tfjs
- Khởi tạo lớp FrozenModel và chạy suy luận.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
Hãy xem bản demo MobileNet .
API loadGraphModel
chấp nhận tham số LoadOptions
bổ sung, tham số này có thể được sử dụng để gửi thông tin xác thực hoặc tiêu đề tùy chỉnh cùng với yêu cầu. Để biết chi tiết, hãy xem tài liệu LoadGraphModel() .
Các hoạt động được hỗ trợ
Hiện tại, TensorFlow.js hỗ trợ một số hoạt động TensorFlow có giới hạn. Nếu mô hình của bạn sử dụng op không được hỗ trợ, tập lệnh tensorflowjs_converter
sẽ không thành công và in ra danh sách các op không được hỗ trợ trong mô hình của bạn. Vui lòng gửi vấn đề cho từng hoạt động để cho chúng tôi biết bạn cần hỗ trợ cho hoạt động nào.
Chỉ tải trọng lượng
Nếu bạn chỉ muốn tải trọng số, bạn có thể sử dụng đoạn mã sau:
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");
// Use `weightMap` ...