개요
참고: 이 API는 새로 추가된 것으로, 처음에는 pip install tf-nightly를 통해서만 사용할 수 있었으며 TensorFlow 2.7 버전부터 정식으로 사용할 수 있습니다. 또한 이 API는 아직 실험 단계이므로 변경될 수 있습니다.
이 코드랩은 Jax를 사용하여 MNIST 인식 모델을 구축하고 이를 TensorFlow Lite로 변환하는 방법을 보여줍니다. 또한 훈련 후 양자화(post-training quantization)를 사용하여 Jax에서 변환한 TFLite 모델을 최적화하는 방법도 보여줍니다.
TensorFlow.org에서 보기 | Google Colab에서 실행 | GitHub에서 소스 보기 | 노트북 다운로드 |
전제 조건
최신 TensorFlow 야간 pip 빌드에서 이 기능을 사용하는 것이 좋습니다.
# Upgrade the nightly TensorFlow build, then JAX and its jaxlib backend.
pip install --upgrade tf-nightly
pip install --upgrade jax
pip install --upgrade jaxlib
데이터 준비
Keras 데이터셋으로 MNIST 데이터를 다운로드하고 전처리합니다.
import numpy as np
import tensorflow as tf
import functools
import time
import itertools
import numpy.random as npr
import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import optimizers
from jax.example_libraries import stax
2022-12-14 20:08:59.063265: E tensorflow/tsl/lib/monitoring/collection_registry.cc:81] Cannot register 2 metrics with the same name: /tensorflow/core/bfc_allocator_delay
def _one_hot(x, k, dtype=np.float32):
    """Return the one-hot encoding of ``x`` over ``k`` classes.

    Each entry of ``x`` is compared against ``0..k-1``; the boolean matches
    are cast to ``dtype``. For a 1-D ``x`` of length n the result has shape
    ``(n, k)``.
    """
    classes = np.arange(k)
    matches = x[:, np.newaxis] == classes  # broadcast labels against class ids
    return np.array(matches, dtype)
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Scale pixels by 1/255 (presumably uint8 values in [0, 255]) and keep them as float32.
train_images = (train_images / 255.0).astype(np.float32)
test_images = (test_images / 255.0).astype(np.float32)

# Turn the integer class labels into 10-way one-hot float rows.
train_labels, test_labels = _one_hot(train_labels, 10), _one_hot(test_labels, 10)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11490434/11490434 [==============================] - 0s 0us/step
Jax로 MNIST 모델 빌드
def loss(params, batch):
    """Mean negative log-likelihood of ``batch`` under the model.

    ``batch`` is an ``(inputs, targets)`` pair with ``targets`` expected to be
    one-hot rows. ``predict`` ends in ``stax.LogSoftmax``, so its outputs are
    log-probabilities; the row-wise dot product with the one-hot target picks
    out the log-probability of the true class.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)
    true_class_log_prob = jnp.sum(log_probs * targets, axis=1)
    return -jnp.mean(true_class_log_prob)
def accuracy(params, batch):
    """Fraction of examples in ``batch`` whose argmax prediction matches the target.

    ``batch`` is an ``(inputs, targets)`` pair; ``targets`` is compared by its
    argmax along axis 1 (i.e. expected to be one-hot rows).
    """
    inputs, targets = batch
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    target_class = jnp.argmax(targets, axis=1)
    correct = predicted_class == target_class
    return jnp.mean(correct)
# MLP: Flatten -> Dense(1024) -> ReLU -> Dense(1024) -> ReLU -> Dense(10) -> LogSoftmax.
# stax.serial returns an (init_fn, apply_fn) pair.
_mlp_layers = (
    stax.Flatten,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(10),
    stax.LogSoftmax,
)
init_random_params, predict = stax.serial(*_mlp_layers)

# Fixed PRNG key so parameter initialization is reproducible.
rng = random.PRNGKey(0)
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
모델 학습 및 평가
# Training hyperparameters.
step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9

num_train = train_images.shape[0]
# Number of minibatches per epoch, rounded up so a trailing partial batch is kept.
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)
def data_stream():
    """Yield ``(images, labels)`` minibatches from the training set forever.

    Each pass draws a fresh permutation from a RandomState seeded with 0 and
    walks it in slices of ``batch_size``; the final slice of a pass may be
    shorter when ``num_train`` is not a multiple of ``batch_size``.
    """
    shuffler = npr.RandomState(0)
    while True:
        order = shuffler.permutation(num_train)
        for start in range(0, num_batches * batch_size, batch_size):
            idx = order[start:start + batch_size]
            yield train_images[idx], train_labels[idx]
# Momentum SGD: returns (init, update, params-extraction) functions.
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)
# Endless minibatch iterator consumed by the training loop.
batches = data_stream()
@jit
def update(i, opt_state, batch):
    """Run one jit-compiled optimizer step on ``batch``.

    Differentiates ``loss`` at the current parameters and feeds the gradients
    into ``opt_update`` together with the step index ``i``; returns the new
    optimizer state.
    """
    current = get_params(opt_state)
    grads = grad(loss)(current, batch)
    return opt_update(i, grads, opt_state)
import jax  # only jax submodules/names are bound above; tree_util needs the package

_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.perf_counter()  # monotonic, high-resolution interval timer
    for _ in range(num_batches):
        opt_state = update(next(itercount), opt_state, next(batches))
    params = get_params(opt_state)
    # JAX dispatches jitted work asynchronously: without waiting here the
    # clock would stop while the last updates are still running, so the
    # reported epoch time would only measure dispatch.
    for leaf in jax.tree_util.tree_leaves(params):
        leaf.block_until_ready()
    epoch_time = time.perf_counter() - start_time

    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print(f"Epoch {epoch} in {epoch_time:0.2f} sec")
    print(f"Training set accuracy {train_acc}")
    print(f"Test set accuracy {test_acc}")
Starting training... Epoch 0 in 2.79 sec Training set accuracy 0.8728833198547363 Test set accuracy 0.880299985408783 Epoch 1 in 2.31 sec Training set accuracy 0.8983833193778992 Test set accuracy 0.9047999978065491 Epoch 2 in 2.32 sec Training set accuracy 0.9102333188056946 Test set accuracy 0.9138000011444092 Epoch 3 in 2.37 sec Training set accuracy 0.9172333478927612 Test set accuracy 0.9218999743461609 Epoch 4 in 2.31 sec Training set accuracy 0.9224833250045776 Test set accuracy 0.9253999590873718 Epoch 5 in 2.31 sec Training set accuracy 0.9272000193595886 Test set accuracy 0.9309999942779541 Epoch 6 in 2.33 sec Training set accuracy 0.9328166842460632 Test set accuracy 0.9334999918937683 Epoch 7 in 2.31 sec Training set accuracy 0.9360166788101196 Test set accuracy 0.9370999932289124 Epoch 8 in 2.32 sec Training set accuracy 0.939050018787384 Test set accuracy 0.939300000667572 Epoch 9 in 2.31 sec Training set accuracy 0.9425666928291321 Test set accuracy 0.9429000020027161
TFLite 모델로 변환합니다.
참고로 다음과 같이 진행합니다.
1. functools.partial을 사용하여 params를 Jax predict 함수에 인라인합니다.
2. jnp.zeros로 텐서를 하나 만듭니다. 이 텐서는 Jax가 모델을 추적(trace)하는 데 사용하는 "자리 표시자" 텐서입니다.
3. experimental_from_jax를 호출합니다.
   - serving_func는 목록으로 래핑됩니다.
   - 입력은 지정된 이름과 연결되고 목록에 래핑된 배열로 전달됩니다.
# Bind the trained params so the exported function takes only the input.
serving_func = functools.partial(predict, params)
# Placeholder of shape (1, 28, 28) that Jax traces the model with.
x_input = jnp.zeros((1, 28, 28))

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func],
    [[('input1', x_input)]],
)
tflite_model = converter.convert()

with open('jax_mnist.tflite', 'wb') as f:
    f.write(tflite_model)
2022-12-14 20:09:33.454726: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format. 2022-12-14 20:09:33.454777: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency. 2022-12-14 20:09:33.454783: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:371] Ignored change_concat_input_ranges.
변환된 TFLite 모델 확인
변환된 모델의 결과를 Jax 모델과 비교하십시오.
expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Assert that the TFLite result is consistent with the JAX model.
# NOTE: the third positional argument of np.testing.assert_almost_equal is
# `decimal` (a digit count), so passing 1e-5 there allows |diff| < ~1.5.
# Use explicit float tolerances instead.
np.testing.assert_allclose(np.asarray(expected), result, rtol=1e-5, atol=1e-5)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
모델 최적화
훈련 후 양자화(post-training quantization)로 모델을 최적화할 수 있도록 representative_dataset을 제공합니다.
def representative_dataset():
    """Yield calibration inputs for the TFLite quantizer.

    Produces the first 1000 training images one at a time, each as a
    single-element list holding a batch-of-one slice of ``train_images``.
    """
    for sample in (train_images[row:row + 1] for row in range(1000)):
        yield [sample]
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('x', x_input)]])

# Configure post-training quantization BEFORE the single convert() call.
# Calling convert() on the unconfigured converter only re-produces the float
# model (an extra full conversion) and would overwrite `tflite_model`.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset  # calibration samples
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# Input/output types are not overridden; the conversion log reports FLOAT32 I/O.
tflite_quant_model = converter.convert()

with open('jax_mnist_quant.tflite', 'wb') as f:
    f.write(tflite_quant_model)
2022-12-14 20:09:33.930065: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format. 2022-12-14 20:09:33.930110: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency. 2022-12-14 20:09:33.930117: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:371] Ignored change_concat_input_ranges. 2022-12-14 20:09:34.206292: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:362] Ignored output_format. 2022-12-14 20:09:34.206341: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:365] Ignored drop_control_dependency. 2022-12-14 20:09:34.206348: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:371] Ignored change_concat_input_ranges. fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32
최적화된 모델 평가
expected = serving_func(train_images[0:1])

# Run the quantized model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Check that the quantized TFLite result is consistent with the Jax model.
# int8 quantization introduces real error, so a 1e-5 match is not expected.
# NOTE: assert_almost_equal's third positional argument is `decimal`, so the
# previous `1e-5` silently allowed |diff| < ~1.5. Keep that bound, but state
# it explicitly, and also require the same predicted class.
expected_np = np.asarray(expected)
np.testing.assert_array_equal(
    np.argmax(expected_np, axis=-1), np.argmax(result, axis=-1))
np.testing.assert_allclose(expected_np, result, atol=1.5)
양자화된 모델 크기 비교
양자화된 모델이 원래 모델보다 4배 더 작은 것을 볼 수 있어야 합니다.
# Print the on-disk size of the float and the quantized TFLite models.
du -h jax_mnist.tflite jax_mnist_quant.tflite
7.2M jax_mnist.tflite 1.8M jax_mnist_quant.tflite