API Lớp của TensorFlow.js được mô phỏng theo Keras và chúng tôi cố gắng làm cho API Lớp tương tự như Keras một cách hợp lý dựa trên sự khác biệt giữa JavaScript và Python. Điều này giúp người dùng có kinh nghiệm phát triển mô hình Keras trong Python di chuyển sang Lớp TensorFlow.js trong JavaScript dễ dàng hơn. Ví dụ: mã Keras sau đây dịch sang JavaScript:
# Python:
import keras
import numpy as np
# Build and compile model.
model = keras.Sequential()
model.add(keras.layers.Dense(units=1, input_shape=[1]))
model.compile(optimizer='sgd', loss='mean_squared_error')
# Generate some synthetic data for training.
xs = np.array([[1], [2], [3], [4]])
ys = np.array([[1], [3], [5], [7]])
# Train model with fit().
model.fit(xs, ys, epochs=1000)
# Run inference with predict().
print(model.predict(np.array([[5]])))
// JavaScript:
import * as tf from '@tensorflow/tfjs';
// Build and compile model.
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
// Generate some synthetic data for training.
const xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);
const ys = tf.tensor2d([[1], [3], [5], [7]], [4, 1]);
// Train model with fit().
await model.fit(xs, ys, {epochs: 1000});
// Run inference with predict().
model.predict(tf.tensor2d([[5]], [1, 1])).print();
Tuy nhiên, có một số khác biệt mà chúng tôi muốn chỉ ra và giải thích trong tài liệu này. Khi bạn hiểu những khác biệt này và lý do đằng sau chúng, quá trình di chuyển Python sang JavaScript (hoặc di chuyển theo hướng ngược lại) của bạn sẽ là một trải nghiệm tương đối suôn sẻ.
Các nhà xây dựng lấy Đối tượng JavaScript làm cấu hình
So sánh các dòng Python và JavaScript sau từ ví dụ trên: cả hai đều tạo ra một Lớp dày đặc .
# Python:
keras.layers.Dense(units=1, inputShape=[1])
// JavaScript:
tf.layers.dense({units: 1, inputShape: [1]});
Các hàm JavaScript không có đối số từ khóa tương đương trong các hàm Python. Chúng tôi muốn tránh triển khai các tùy chọn hàm tạo làm đối số vị trí trong JavaScript, điều này sẽ đặc biệt khó nhớ và khó sử dụng đối với các hàm tạo có số lượng lớn đối số từ khóa (ví dụ: LSTM ). Đây là lý do tại sao chúng tôi sử dụng các đối tượng cấu hình JavaScript. Các đối tượng như vậy cung cấp cùng mức độ bất biến về vị trí và tính linh hoạt như các đối số từ khóa Python.
Một số phương thức của lớp Model, ví dụ: Model.compile()
, cũng lấy đối tượng cấu hình JavaScript làm đầu vào. Tuy nhiên, hãy nhớ rằng Model.fit()
, Model.evaluate()
và Model.predict()
hơi khác nhau. Vì các phương pháp này lấy dữ liệu x
(tính năng) và y
(nhãn hoặc mục tiêu) bắt buộc làm đầu vào; x
và y
là các đối số vị trí tách biệt với đối tượng cấu hình tiếp theo đóng vai trò đối số từ khóa. Ví dụ:
// JavaScript:
await model.fit(xs, ys, {epochs: 1000});
Model.fit() không đồng bộ
Model.fit()
là phương pháp chính để người dùng thực hiện đào tạo mô hình trong TensorFlow.js. Phương pháp này thường có thể kéo dài, kéo dài trong vài giây hoặc vài phút. Do đó, chúng tôi sử dụng tính năng async
của ngôn ngữ JavaScript để có thể sử dụng chức năng này theo cách không chặn luồng giao diện người dùng chính khi chạy trên trình duyệt. Điều này tương tự với các hàm có khả năng chạy lâu khác trong JavaScript, chẳng hạn như tìm nạp async
. Lưu ý rằng async
là cấu trúc không tồn tại trong Python. Trong khi phương thức fit()
trong Keras trả về một đối tượng History, thì phương thức tương tự của phương thức fit()
trong JavaScript trả về Promise of History, có thể chờ ed (như trong ví dụ trên) hoặc được sử dụng với phương thức then().
Không có NumPy cho TensorFlow.js
Người dùng Python Keras thường sử dụng NumPy để thực hiện các phép toán số và mảng cơ bản, chẳng hạn như tạo các tensor 2D trong ví dụ trên.
# Python:
xs = np.array([[1], [2], [3], [4]])
Trong TensorFlow.js, loại phép toán số cơ bản này được thực hiện bằng chính gói đó. Ví dụ:
// JavaScript:
const xs = tf.tensor2d([[1], [2], [3], [4]], [4, 1]);
Không gian tên tf.*
cũng cung cấp một số hàm khác cho các phép toán mảng và đại số tuyến tính chẳng hạn như phép nhân ma trận. Xem tài liệu cốt lõi của TensorFlow.js để biết thêm thông tin.
Sử dụng các phương thức xuất xưởng, không phải hàm tạo
Dòng này trong Python (từ ví dụ trên) là một lệnh gọi hàm tạo:
# Python:
model = keras.Sequential()
Nếu được dịch hoàn toàn sang JavaScript, lệnh gọi hàm tạo tương đương sẽ giống như sau:
// JavaScript:
const model = new tf.Sequential(); // !!! DON'T DO THIS !!!
Tuy nhiên, chúng tôi quyết định không sử dụng hàm tạo “mới” vì 1) từ khóa “mới” sẽ làm cho mã trở nên cồng kềnh hơn và 2) hàm tạo “mới” bị coi là “phần xấu” của JavaScript: một cạm bẫy tiềm ẩn, vì được tranh luận bằng JavaScript: the Good Parts . Để tạo mô hình và lớp trong TensorFlow.js, bạn gọi các phương thức xuất xưởng có tên lowCamelCase, ví dụ:
// JavaScript:
const model = tf.sequential();
const layer = tf.layers.batchNormalization({axis: 1});
Giá trị chuỗi tùy chọn là lowCamelCase, không phải snake_case
Trong JavaScript, việc sử dụng kiểu chữ lạc đà cho tên biểu tượng là phổ biến hơn (ví dụ: xem Hướng dẫn về kiểu JavaScript của Google ), so với Python, trong đó kiểu chữ rắn là phổ biến (ví dụ: trong Keras). Do đó, chúng tôi đã quyết định sử dụng lowCamelCase cho các giá trị chuỗi cho các tùy chọn bao gồm:
- DataFormat, ví dụ:
channelsFirst
thaychannels_first
- Trình khởi tạo, ví dụ:
glorotNormal
thay vìglorot_normal
- Mất mát và số liệu, ví dụ:
meanSquaredError
thay vìmean_squared_error
,categoricalCrossentropy
thay vìcategorical_crossentropy
.
Ví dụ, như trong ví dụ trên:
// JavaScript:
model.compile({optimizer: 'sgd', loss: 'meanSquaredError'});
Về vấn đề tuần tự hóa và giải tuần tự hóa mô hình, bạn hãy yên tâm. Cơ chế nội bộ của TensorFlow.js đảm bảo rằng các trường hợp rắn trong đối tượng JSON được xử lý chính xác, ví dụ: khi tải các mô hình được đào tạo trước từ Python Keras.
Chạy các đối tượng Layer bằng apply(), không phải bằng cách gọi chúng là hàm
Trong Keras, đối tượng Layer có phương thức __call__
được xác định. Do đó, người dùng có thể gọi logic của lớp bằng cách gọi đối tượng dưới dạng hàm, ví dụ:
# Python:
my_input = keras.Input(shape=[2, 4])
flatten = keras.layers.Flatten()
print(flatten(my_input).shape)
Đường cú pháp Python này được triển khai dưới dạng phương thức apply() trong TensorFlow.js:
// JavaScript:
const myInput = tf.input({shape: [2, 4]});
const flatten = tf.layers.flatten();
console.log(flatten.apply(myInput).shape);
Layer.apply() hỗ trợ đánh giá mệnh lệnh (háo hức) trên các tensor bê tông
Hiện tại, trong Keras, phương thức gọi chỉ có thể hoạt động trên (Python) các đối tượng tf.Tensor
của TensorFlow (giả sử chương trình phụ trợ TensorFlow), mang tính biểu tượng và không chứa các giá trị số thực tế. Đây là những gì được thể hiện trong ví dụ ở phần trước. Tuy nhiên, trong TensorFlow.js, phương thức apply() của các lớp có thể hoạt động ở cả chế độ tượng trưng và mệnh lệnh. Nếu apply()
được gọi bằng SymbolicTensor (tương tự tf.Tensor), giá trị trả về sẽ là SymbolicTensor. Điều này thường xảy ra trong quá trình xây dựng mô hình. Nhưng nếu apply()
được gọi với giá trị Tensor cụ thể thực tế, nó sẽ trả về một Tensor cụ thể. Ví dụ:
// JavaScript:
const flatten = tf.layers.flatten();
flatten.apply(tf.ones([2, 3, 4])).print();
Tính năng này gợi nhớ đến Eager Execution của Python (Python). Nó mang lại khả năng tương tác và khả năng sửa lỗi cao hơn trong quá trình phát triển mô hình, ngoài ra còn mở ra cánh cửa để tạo nên các mạng lưới thần kinh động.
Trình tối ưu hóa đang được đào tạo. , không phải trình tối ưu hóa.
Trong Keras, các hàm tạo cho đối tượng Optimizer nằm trong không gian tên keras.optimizers.*
. Trong Lớp TensorFlow.js, các phương thức ban đầu dành cho Trình tối ưu hóa nằm trong không gian tên tf.train.*
. Ví dụ:
# Python:
my_sgd = keras.optimizers.sgd(lr=0.2)
// JavaScript:
const mySGD = tf.train.sgd({lr: 0.2});
LoadLayersModel() tải từ URL chứ không phải tệp HDF5
Trong Keras, các mô hình thường được lưu dưới dạng tệp HDF5 (.h5), tệp này có thể được tải sau bằng phương thức keras.models.load_model()
. Phương thức này có đường dẫn đến tệp .h5. Bản sao của load_model()
trong TensorFlow.js là tf.loadLayersModel()
. Vì HDF5 không phải là định dạng tệp thân thiện với trình duyệt nên tf.loadLayersModel()
có định dạng dành riêng cho TensorFlow.js. tf.loadLayersModel()
lấy tệp model.json làm đối số đầu vào. Model.json có thể được chuyển đổi từ tệp Keras HDF5 bằng gói pip tensorflowjs.
// JavaScript:
const model = await tf.loadLayersModel('https://foo.bar/model.json');
Cũng lưu ý rằng tf.loadLayersModel()
trả về Promise
là tf.Model
.
Nói chung, việc lưu và tải tf.Model
trong TensorFlow.js được thực hiện bằng cách sử dụng các phương thức tf.Model.save
và tf.loadLayersModel
tương ứng. Chúng tôi đã thiết kế các API này tương tự như API save và Load_model của Keras. Nhưng môi trường trình duyệt khá khác với môi trường phụ trợ chạy các khung học sâu chủ yếu như Keras, đặc biệt là trong mảng các tuyến đường để duy trì và truyền dữ liệu. Do đó, có một số khác biệt thú vị giữa các API lưu/tải trong TensorFlow.js và trong Keras. Xem hướng dẫn của chúng tôi về Lưu và tải tf.Model để biết thêm chi tiết.
Sử dụng fitDataset()
để huấn luyện mô hình bằng cách sử dụng đối tượng tf.data.Dataset
Trong tf.keras của Python TensorFlow, một mô hình có thể được huấn luyện bằng cách sử dụng đối tượng Dataset . Phương thức fit()
của mô hình chấp nhận trực tiếp một đối tượng như vậy. Một mô hình TensorFlow.js cũng có thể được huấn luyện bằng JavaScript tương đương với các đối tượng Dataset (xem tài liệu về API tf.data trong TensorFlow.js ). Tuy nhiên, không giống như Python, việc đào tạo dựa trên Tập dữ liệu được thực hiện thông qua một phương pháp chuyên dụng, cụ thể là fitDataset . Phương thức fit() chỉ dành cho việc huấn luyện mô hình dựa trên tensor.
Quản lý bộ nhớ của các đối tượng Lớp và Mô hình
TensorFlow.js chạy trên WebGL trong trình duyệt, trong đó trọng số của các đối tượng Lớp và Mô hình được hỗ trợ bởi kết cấu WebGL. Tuy nhiên, WebGL không có hỗ trợ thu gom rác tích hợp. Các đối tượng Lớp và Mô hình quản lý nội bộ bộ nhớ tensor cho người dùng trong các lệnh gọi suy luận và huấn luyện của họ. Nhưng chúng cũng cho phép người dùng loại bỏ chúng để giải phóng bộ nhớ WebGL mà chúng chiếm giữ. Điều này hữu ích trong trường hợp nhiều phiên bản mô hình được tạo và phát hành trong một lần tải trang. Để loại bỏ một đối tượng Layer hoặc Model, hãy sử dụng phương thức dispose()
.