TensorFlow.org で表示 | Google Colab で実行 | GitHub でソースを表示 | ノートブックをダウンロード |
TensorFlow Lite(TFLite)は、開発者がデバイス(モバイル、組み込み、IoT デバイス)で ML 推論を実行するのに役立つ一連のツールです。TFLite コンバータは、既存の TF モデルをデバイス上で効率的に実行できる最適化された TFLite モデル形式に変換するツールの 1 つです。
このドキュメントでは、TF から TFLite への変換コードにどのような変更を加える必要があるかを説明し、いくつかの例を示します。
TF から TFLite への変換コードの変更
従来の TF1 モデル形式(Keras ファイル、凍結された GraphDef、チェックポイント、tf.Session など)を使用している場合は、それを TF1/TF2 SavedModel に更新し、TF2 コンバータ API
tf.lite.TFLiteConverter.from_saved_model(...)
を使用して TFLite モデルに変換します(表 1 を参照)。コンバータ API フラグを更新します(表 2 を参照)。
tflite.constants
などのレガシー API を削除します。(例:tf.lite.constants.INT8
をtf.int8
に置き換えます)
// 表 1 // TFLite Python コンバータ API の更新
TF1 API | TF2 API |
---|---|
tf.lite.TFLiteConverter.from_saved_model('saved_model/',..) |
サポートされています |
tf.lite.TFLiteConverter.from_keras_model_file('model.h5',..) |
削除(SavedModel 形式に更新) |
tf.lite.TFLiteConverter.from_frozen_graph('model.pb',..) |
削除(SavedModel 形式に更新) |
tf.lite.TFLiteConverter.from_session(sess,...) |
削除(SavedModel 形式に更新) |
<style> .table {margin-left: 0 !important;} </style>
// 表 2 // TFLite Python コンバータ API フラグの更新
TF1 API | TF2 API |
---|---|
allow_custom_ops optimizations representative_dataset target_spec inference_input_type inference_output_type experimental_new_converter experimental_new_quantizer |
サポートされています |
input_tensors output_tensors input_arrays_with_shape output_arrays experimental_debug_info_func |
削除(コンバータ API 引数がサポートされていません) |
change_concat_input_ranges default_ranges_stats get_input_arrays() inference_type quantized_input_stats reorder_across_fake_quant |
削除(量子化ワークフローがサポートされていません) |
conversion_summary_dir dump_graphviz_dir dump_graphviz_video |
削除(モデルを可視化するには、Netron または visualize.py を使用します) |
output_format drop_control_dependency |
削除(TF2 でサポートされていない機能) |
例
ここでは、レガシー TF1 モデルを TF1/TF2 SavedModel に変換し、それらを TF2 TFLite モデルに変換するいくつかの例について説明します。
セットアップ
必要な TensorFlow インポートから始めます。
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import numpy as np
import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)
import shutil
def remove_dir(path):
try:
shutil.rmtree(path)
except:
pass
2024-01-11 18:10:09.106741: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-01-11 18:10:09.106786: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-01-11 18:10:09.108429: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
必要なすべての TF1 モデル形式を作成します。
# Create a TF1 SavedModel
SAVED_MODEL_DIR = "tf_saved_model/"
remove_dir(SAVED_MODEL_DIR)
with tf1.Graph().as_default() as g:
with tf1.Session() as sess:
input = tf1.placeholder(tf.float32, shape=(3,), name='input')
output = input + 2
# print("result: ", sess.run(output, {input: [0., 2., 4.]}))
tf1.saved_model.simple_save(
sess, SAVED_MODEL_DIR,
inputs={'input': input},
outputs={'output': output})
print("TF1 SavedModel path: ", SAVED_MODEL_DIR)
# Create a TF1 Keras model
KERAS_MODEL_PATH = 'tf_keras_model.h5'
model = tf1.keras.models.Sequential([
tf1.keras.layers.InputLayer(input_shape=(128, 128, 3,), name='input'),
tf1.keras.layers.Dense(units=16, input_shape=(128, 128, 3,), activation='relu'),
tf1.keras.layers.Dense(units=1, name='output')
])
model.save(KERAS_MODEL_PATH, save_format='h5')
print("TF1 Keras Model path: ", KERAS_MODEL_PATH)
# Create a TF1 frozen GraphDef model
GRAPH_DEF_MODEL_PATH = tf.keras.utils.get_file(
'mobilenet_v1_0.25_128',
origin='https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_128_frozen.tgz',
untar=True,
) + '/frozen_graph.pb'
print("TF1 frozen GraphDef path: ", GRAPH_DEF_MODEL_PATH)
TF1 SavedModel path: tf_saved_model/ TF1 Keras Model path: tf_keras_model.h5 Downloading data from https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_128_frozen.tgz 2617289/2617289 [==============================] - 0s 0us/step /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/training.py:3103: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native Keras format, e.g. `model.save('my_model.keras')`. saving_api.save_model( TF1 frozen GraphDef path: /home/kbuilder/.keras/datasets/mobilenet_v1_0.25_128/frozen_graph.pb
1. TF1 SavedModel を TFLite モデルに変換する
以前: TF1 で変換する
以下は、TF1 スタイルの TFlite 変換の典型的なコードです。
converter = tf1.lite.TFLiteConverter.from_saved_model(
saved_model_dir=SAVED_MODEL_DIR,
input_arrays=['input'],
input_shapes={'input' : [3]}
)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()
# Ignore warning: "Use '@tf.function' or '@defun' to decorate the function."
2024-01-11 18:10:13.887055: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2024-01-11 18:10:13.887089: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. 2024-01-11 18:10:13.887097: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges. Summary on the non-converted ops: --------------------------------- * Accepted dialects: tfl, builtin, func * Non-Converted Ops: 1, Total Ops 5, % non-converted = 20.00 % * 1 ARITH ops - arith.constant: 1 occurrences (f32: 1) (f32: 1)
更新後: TF2 で変換する
TF1 SavedModel を TFLite モデルに直接変換し、より小さい v2 コンバータフラグを設定します。
# Convert TF1 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=SAVED_MODEL_DIR)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
tflite_model = converter.convert()
2024-01-11 18:10:13.950640: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2024-01-11 18:10:13.950675: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. Summary on the non-converted ops: --------------------------------- * Accepted dialects: tfl, builtin, func * Non-Converted Ops: 1, Total Ops 5, % non-converted = 20.00 % * 1 ARITH ops - arith.constant: 1 occurrences (f32: 1) (f32: 1)
2. TF1 Keras モデルファイルを TFLite モデルに変換する
以前: TF1 で変換する
これは、TF1 スタイルの TFlite 変換の典型的なコードです。
converter = tf1.lite.TFLiteConverter.from_keras_model_file(model_file=KERAS_MODEL_PATH)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()
2024-01-11 18:10:14.922921: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2024-01-11 18:10:14.922955: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. 2024-01-11 18:10:14.922961: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges. Summary on the non-converted ops: --------------------------------- * Accepted dialects: tfl, builtin, func * Non-Converted Ops: 9, Total Ops 35, % non-converted = 25.71 % * 9 ARITH ops - arith.constant: 9 occurrences (f32: 4, i32: 5) (f32: 2) (i32: 2) (f32: 2) (i32: 4) (i32: 2) (i32: 4) (f32: 4) (i32: 2)
更新後: TF2 で変換する
TF1 Keras モデルファイルを TF2 SavedModel に変換してから、より小さな v2 コンバータフラグを設定し、それを TFLite モデルに変換します。
# Convert TF1 Keras model file to TF2 SavedModel.
model = tf.keras.models.load_model(KERAS_MODEL_PATH)
model.save(filepath='saved_model_2/')
# Convert TF2 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_2/')
tflite_model = converter.convert()
2024-01-11 18:10:15.434771: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2024-01-11 18:10:15.434806: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. Summary on the non-converted ops: --------------------------------- * Accepted dialects: tfl, builtin, func * Non-Converted Ops: 9, Total Ops 35, % non-converted = 25.71 % * 9 ARITH ops - arith.constant: 9 occurrences (f32: 4, i32: 5) (f32: 2) (i32: 2) (f32: 2) (i32: 4) (i32: 2) (i32: 4) (f32: 4) (i32: 2)
3. TF1 で凍結された GraphDef を TFLite モデルに変換する
以前: TF1 で変換する
これは、TF1 スタイルの TFlite 変換の典型的なコードです。
converter = tf1.lite.TFLiteConverter.from_frozen_graph(
graph_def_file=GRAPH_DEF_MODEL_PATH,
input_arrays=['input'],
input_shapes={'input' : [1, 128, 128, 3]},
output_arrays=['MobilenetV1/Predictions/Softmax'],
)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()
2024-01-11 18:10:15.649598: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2024-01-11 18:10:15.649630: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. 2024-01-11 18:10:15.649636: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges. Summary on the non-converted ops: --------------------------------- * Accepted dialects: tfl, builtin, func * Non-Converted Ops: 38, Total Ops 91, % non-converted = 41.76 % * 38 ARITH ops - arith.constant: 38 occurrences (f32: 37, i32: 1) (f32: 1) (f32: 15) (f32: 13) (uq_8: 19) (f32: 1) (f32: 1)
更新後: TF2 で変換する
TF1 で凍結された GraphDef を TF1 SavedModel に変換してから、より小さな v2 コンバータフラグを設定して、それを TFLite モデルに変換します。
## Convert TF1 frozen Graph to TF1 SavedModel.
# Load the graph as a v1.GraphDef
import pathlib
gdef = tf.compat.v1.GraphDef()
gdef.ParseFromString(pathlib.Path(GRAPH_DEF_MODEL_PATH).read_bytes())
# Convert the GraphDef to a tf.Graph
with tf.Graph().as_default() as g:
tf.graph_util.import_graph_def(gdef, name="")
# Look up the input and output tensors.
input_tensor = g.get_tensor_by_name('input:0')
output_tensor = g.get_tensor_by_name('MobilenetV1/Predictions/Softmax:0')
# Save the graph as a TF1 Savedmodel
remove_dir('saved_model_3/')
with tf.compat.v1.Session(graph=g) as s:
tf.compat.v1.saved_model.simple_save(
session=s,
export_dir='saved_model_3/',
inputs={'input':input_tensor},
outputs={'output':output_tensor})
# Convert TF1 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_3/')
converter.optimizations = {tf.lite.Optimize.DEFAULT}
tflite_model = converter.convert()
2024-01-11 18:10:16.454604: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2024-01-11 18:10:16.454643: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. Summary on the non-converted ops: --------------------------------- * Accepted dialects: tfl, builtin, func * Non-Converted Ops: 38, Total Ops 91, % non-converted = 41.76 % * 38 ARITH ops - arith.constant: 38 occurrences (f32: 37, i32: 1) (f32: 1) (f32: 15) (f32: 13) (uq_8: 19) (f32: 1) (f32: 1)
参考資料
- ワークフローと最新機能の詳細については、TFLite ガイドを参照してください。
- TF1 コードまたは従来の TF1 モデル形式(Keras
.h5
ファイル、凍結された GraphDef.pb
など)を使用している場合、コードを更新し、モデルを TF2 SavedModel 形式に移行してください。