TensorFlow.org에서 보기 | Run in Google Colab | View source on GitHub | 노트북 다운로드하기 |
개요
TensorFlow 모델 최적화 도구 키트의 일부인 가중치 클러스터링에 대한 엔드 투 엔드 예제를 소개합니다.
기타 페이지
가중치 클러스터링에 대한 소개와 이를 사용해야 하는지 여부(지원 내용 포함)를 결정하려면 개요 페이지를 참조하세요.
16개의 클러스터로 모델을 완전하게 클러스터링하는 등 해당 사용 사례에 필요한 API를 빠르게 찾으려면 종합 가이드를 참조하세요.
내용
이 튜토리얼에서는 다음을 수행합니다.
- MNIST 데이터세트를 위한
tf.keras
모델을 처음부터 훈련합니다. - 가중치 클러스터링 API를 적용하여 모델을 미세 조정하고 정확성을 확인합니다.
- 클러스터링으로부터 6배 더 작은 TF 및 TFLite 모델을 만듭니다.
- 가중치 클러스터링과 훈련 후 양자화를 결합하여 8배 더 작은 TFLite 모델을 만듭니다.
- TF에서 TFLite로 정확성이 지속되는지 확인합니다.
설정
이 Jupyter 노트북은 로컬 virtualenv 또는 colab에서 실행할 수 있습니다. 종속성 설정에 대한 자세한 내용은 설치 가이드를 참조하세요.
pip install -q tensorflow-model-optimization
import tensorflow as tf
from tensorflow import keras
import numpy as np
import tempfile
import zipfile
import os
클러스터링을 사용하지 않고 MNIST용 tf.keras 모델 훈련하기
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0
# Define the model architecture.
model = keras.Sequential([
keras.layers.InputLayer(input_shape=(28, 28)),
keras.layers.Reshape(target_shape=(28, 28, 1)),
keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),
keras.layers.MaxPooling2D(pool_size=(2, 2)),
keras.layers.Flatten(),
keras.layers.Dense(10)
])
# Train the digit classification model
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(
train_images,
train_labels,
validation_split=0.1,
epochs=10
)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step Epoch 1/10 1688/1688 [==============================] - 7s 3ms/step - loss: 0.5413 - accuracy: 0.8495 - val_loss: 0.1231 - val_accuracy: 0.9673 Epoch 2/10 1688/1688 [==============================] - 4s 2ms/step - loss: 0.1347 - accuracy: 0.9603 - val_loss: 0.0868 - val_accuracy: 0.9773 Epoch 3/10 1688/1688 [==============================] - 4s 2ms/step - loss: 0.0935 - accuracy: 0.9726 - val_loss: 0.0695 - val_accuracy: 0.9822 Epoch 4/10 1688/1688 [==============================] - 4s 2ms/step - loss: 0.0708 - accuracy: 0.9797 - val_loss: 0.0681 - val_accuracy: 0.9817 Epoch 5/10 1688/1688 [==============================] - 4s 2ms/step - loss: 0.0616 - accuracy: 0.9813 - val_loss: 0.0611 - val_accuracy: 0.9835 Epoch 6/10 1688/1688 [==============================] - 4s 2ms/step - loss: 0.0515 - accuracy: 0.9849 - val_loss: 0.0609 - val_accuracy: 0.9838 Epoch 7/10 1688/1688 [==============================] - 4s 2ms/step - loss: 0.0488 - accuracy: 0.9856 - val_loss: 0.0562 - val_accuracy: 0.9855 Epoch 8/10 1688/1688 [==============================] - 4s 2ms/step - loss: 0.0428 - accuracy: 0.9867 - val_loss: 0.0646 - val_accuracy: 0.9808 Epoch 9/10 1688/1688 [==============================] - 4s 2ms/step - loss: 0.0379 - accuracy: 0.9887 - val_loss: 0.0617 - val_accuracy: 0.9848 Epoch 10/10 1688/1688 [==============================] - 4s 2ms/step - loss: 0.0338 - accuracy: 0.9904 - val_loss: 0.0553 - val_accuracy: 0.9853 <tensorflow.python.keras.callbacks.History at 0x7fde60ea4748>
기준 모델을 평가하고 나중에 사용할 수 있도록 저장하기
_, baseline_model_accuracy = model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
tf.keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9814000129699707 Saving model to: /tmp/tmpaxhvi0yg.h5
클러스터링을 사용하여 사전 훈련된 모델 미세 조정하기
사전 훈련된 전체 모델에 cluster_weights()
API를 적용하여 압축 후 적절한 정확성을 유지하면서 모델 크기가 줄어드는 효과를 입증합니다. 해당 사용 사례에서 정확성과 압축률의 균형을 가장 잘 유지하는 방법은 포괄적 가이드의 레이어별 예를 참조하세요.
모델 정의 및 클러스터링 API 적용하기
모델을 클러스터링 API로 전달하기 전에 모델이 훈련되었고 수용 가능한 정확성을 보이는지 확인합니다.
import tensorflow_model_optimization as tfmot
cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization
clustering_params = {
'number_of_clusters': 16,
'cluster_centroids_init': CentroidInitialization.LINEAR
}
# Cluster a whole model
clustered_model = cluster_weights(model, **clustering_params)
# Use smaller learning rate for fine-tuning clustered model
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
clustered_model.compile(
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=opt,
metrics=['accuracy'])
clustered_model.summary()
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= cluster_reshape (ClusterWeig (None, 28, 28, 1) 0 _________________________________________________________________ cluster_conv2d (ClusterWeigh (None, 26, 26, 12) 136 _________________________________________________________________ cluster_max_pooling2d (Clust (None, 13, 13, 12) 0 _________________________________________________________________ cluster_flatten (ClusterWeig (None, 2028) 0 _________________________________________________________________ cluster_dense (ClusterWeight (None, 10) 20306 ================================================================= Total params: 20,442 Trainable params: 54 Non-trainable params: 20,388 _________________________________________________________________
모델을 미세 조정하고 기준 대비 정확성 평가하기
하나의 epoch 동안 클러스터링이 있는 모델을 미세 조정합니다.
# Fine-tune model
clustered_model.fit(
train_images,
train_labels,
batch_size=500,
epochs=1,
validation_split=0.1)
108/108 [==============================] - 1s 5ms/step - loss: 0.0423 - accuracy: 0.9851 - val_loss: 0.0644 - val_accuracy: 0.9828 <tensorflow.python.keras.callbacks.History at 0x7fdecf7e9b38>
이 예의 경우, 기준과 비교하여 클러스터링 후 테스트 정확성의 손실이 미미합니다.
_, clustered_model_accuracy = clustered_model.evaluate(
test_images, test_labels, verbose=0)
print('Baseline test accuracy:', baseline_model_accuracy)
print('Clustered test accuracy:', clustered_model_accuracy)
Baseline test accuracy: 0.9814000129699707 Clustered test accuracy: 0.9786999821662903
클러스터링으로부터 6배 더 작은 모델 만들기
클러스터링의 압축 이점을 확인하려면 strip_clustering
과 표준 압축 알고리즘(예: gzip 이용) 적용이 모두 필요합니다.
먼저, TensorFlow를 위한 압축 가능한 모델을 만듭니다. 여기서, strip_clustering
은 훈련 중에만 클러스터링에 필요한 모든 변수(예: 클러스터 중심과 인덱스를 저장하기 위한 tf.Variable
)를 제거합니다. 이러한 변수를 제거하지 않으면 추론 중에 모델 크기가 증가하게 됩니다.
final_model = tfmot.clustering.keras.strip_clustering(clustered_model)
_, clustered_keras_file = tempfile.mkstemp('.h5')
print('Saving clustered model to: ', clustered_keras_file)
tf.keras.models.save_model(final_model, clustered_keras_file,
include_optimizer=False)
Saving clustered model to: /tmp/tmpvjj70etk.h5
그런 다음, TFLite를 위한 압축 가능한 모델을 만듭니다. 클러스터링된 모델을 대상 백엔드에서 실행 가능한 형식으로 변환할 수 있습니다. TensorFlow Lite는 모바일 기기에 배포하는 데 사용할 수 있는 예입니다.
clustered_tflite_file = '/tmp/clustered_mnist.tflite'
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
tflite_clustered_model = converter.convert()
with open(clustered_tflite_file, 'wb') as f:
f.write(tflite_clustered_model)
print('Saved clustered TFLite model to:', clustered_tflite_file)
INFO:tensorflow:Assets written to: /tmp/tmpg3n3nbv4/assets Saved clustered TFLite model to: /tmp/clustered_mnist.tflite
실제로 gzip을 통해 모델을 압축하는 도우미 함수를 정의하고 압축된 크기를 측정합니다.
def get_gzipped_model_size(file):
# It returns the size of the gzipped model in bytes.
import os
import zipfile
_, zipped_file = tempfile.mkstemp('.zip')
with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
f.write(file)
return os.path.getsize(zipped_file)
클러스터링으로부터 모델이 6배 더 작아진 것을 확인하세요.
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped clustered Keras model: %.2f bytes" % (get_gzipped_model_size(clustered_keras_file)))
print("Size of gzipped clustered TFlite model: %.2f bytes" % (get_gzipped_model_size(clustered_tflite_file)))
Size of gzipped baseline Keras model: 77991.00 bytes Size of gzipped clustered Keras model: 12656.00 bytes Size of gzipped clustered TFlite model: 12178.00 bytes
가중치 클러스터링과 훈련 후 양자화를 결합하여 8배 더 작은 TFLite 모델 만들기
추가적인 이점을 얻기 위해 클러스터링한 모델에 훈련 후 양자화를 적용할 수 있습니다.
converter = tf.lite.TFLiteConverter.from_keras_model(final_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
_, quantized_and_clustered_tflite_file = tempfile.mkstemp('.tflite')
with open(quantized_and_clustered_tflite_file, 'wb') as f:
f.write(tflite_quant_model)
print('Saved quantized and clustered TFLite model to:', quantized_and_clustered_tflite_file)
print("Size of gzipped baseline Keras model: %.2f bytes" % (get_gzipped_model_size(keras_file)))
print("Size of gzipped clustered and quantized TFlite model: %.2f bytes" % (get_gzipped_model_size(quantized_and_clustered_tflite_file)))
INFO:tensorflow:Assets written to: /tmp/tmprx_pw7t5/assets INFO:tensorflow:Assets written to: /tmp/tmprx_pw7t5/assets Saved quantized and clustered TFLite model to: /tmp/tmpq1ux2lbh.tflite Size of gzipped baseline Keras model: 77991.00 bytes Size of gzipped clustered and quantized TFlite model: 9335.00 bytes
TF에서 TFLite로 정확성이 지속되는지 확인하기
테스트 데이터세트에서 TFLite 모델을 평가하는 도우미 함수를 정의합니다.
def eval_model(interpreter):
input_index = interpreter.get_input_details()[0]["index"]
output_index = interpreter.get_output_details()[0]["index"]
# Run predictions on every image in the "test" dataset.
prediction_digits = []
for i, test_image in enumerate(test_images):
if i % 1000 == 0:
print('Evaluated on {n} results so far.'.format(n=i))
# Pre-processing: add batch dimension and convert to float32 to match with
# the model's input data format.
test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
interpreter.set_tensor(input_index, test_image)
# Run inference.
interpreter.invoke()
# Post-processing: remove batch dimension and find the digit with highest
# probability.
output = interpreter.tensor(output_index)
digit = np.argmax(output()[0])
prediction_digits.append(digit)
print('\n')
# Compare prediction results with ground truth labels to calculate accuracy.
prediction_digits = np.array(prediction_digits)
accuracy = (prediction_digits == test_labels).mean()
return accuracy
클러스터링되고 양자화된 모델을 평가한 다음, TensorFlow의 정확성이 TFLite 백엔드까지 유지되는지 확인합니다.
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
test_accuracy = eval_model(interpreter)
print('Clustered and quantized TFLite test_accuracy:', test_accuracy)
print('Clustered TF test accuracy:', clustered_model_accuracy)
Evaluated on 0 results so far. Evaluated on 1000 results so far. Evaluated on 2000 results so far. Evaluated on 3000 results so far. Evaluated on 4000 results so far. Evaluated on 5000 results so far. Evaluated on 6000 results so far. Evaluated on 7000 results so far. Evaluated on 8000 results so far. Evaluated on 9000 results so far. Clustered and quantized TFLite test_accuracy: 0.9789 Clustered TF test accuracy: 0.9786999821662903
결론
이 튜토리얼에서는 TensorFlow 모델 최적화 도구 키트 API를 사용하여 클러스터링된 모델을 만드는 방법을 알아보았습니다. 구체적으로, 정확성 차이를 최소화하면서 8배 더 작은 MNIST용 모델을 생성하기 위한 엔드 투 엔드 예제를 살펴보았습니다. 리소스가 제한된 환경에서 배포할 때 특히 중요할 수 있는 이 새로운 기능을 한 번 사용해 보세요.