TensorFlow.org で表示 | Google Colab で実行 | GitHubでソースを表示 | ノートブックをダウンロード |
TensorFlow Lite では、TensorFlow モデルの入出力仕様を TensorFlow Lite モデルに変換できます。入出力仕様は「シグネチャ」と呼ばれます。署名は、SavedModel の構築時または具体的な関数の作成時に指定できます。
TensorFlow Lite のシグネチャには次の機能があります。
- TensorFlow モデルのシグネチャを守ることで、変換された TensorFlow Lite モデルの入出力を指定する。
- 1 つの TensorFlow Lite モデルで複数の入力点をサポートできる。
シグネチャには次の 3 つの要素があります。
- 入力: シグネチャの入力名から入力テンソルへの入力のマッピング。
- 出力: シグネチャの出力名から出力テンソルへの出力のマッピング。
- シグネチャキー: グラフの入力点を識別する名前。
MNIST モデルをビルドする
import tensorflow as tf
サンプル モデル
エンコードとデコードといった 2 つのタスクが TensorFlow モデルとして存在するとします。
class Model(tf.Module):
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def encode(self, x):
result = tf.strings.as_string(x)
return {
"encoded_result": result
}
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
def decode(self, x):
result = tf.strings.to_number(x)
return {
"decoded_result": result
}
シグネチャという点では、上記の TensorFlow モデルは次のように要約することができます。
シグネチャ
- キー: encode
- 入力: {"x"}
- 出力: {"encoded_result"}
シグネチャ
- キー: decode
- 入力: {"x"}
- 出力: {"decoded_result"}
シグネチャを使用したモデルの変換
TensorFlow Lite コンバータ API は、上記のシグネチャ情報を変換された TensorFlow Lite モデルに渡します。
この変換機能は、TensorFlow バージョン 2.7.0 以降のすべてのコンバータ API で提供されています。使用例を参照してください。
保存されたモデルから変換
model = Model()
# Save the model
SAVED_MODEL_PATH = 'content/saved_models/coding'
tf.saved_model.save(
model, SAVED_MODEL_PATH,
signatures={
'encode': model.encode.get_concrete_function(),
'decode': model.decode.get_concrete_function()
})
# Convert the saved model using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_saved_model(SAVED_MODEL_PATH)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
tflite_model = converter.convert()
# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
Keras モデルから変換
# Generate a Keras model.
keras_model = tf.keras.Sequential(
[
tf.keras.layers.Dense(2, input_dim=4, activation='relu', name='x'),
tf.keras.layers.Dense(1, activation='relu', name='output'),
]
)
# Convert the keras model using TFLiteConverter.
# Keras model converter API uses the default signature automatically.
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()
# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
Concrete 関数から変換
model = Model()
# Convert the concrete functions using TFLiteConverter
converter = tf.lite.TFLiteConverter.from_concrete_functions(
[model.encode.get_concrete_function(),
model.decode.get_concrete_function()], model)
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops.
tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops.
]
tflite_model = converter.convert()
# Print the signatures from the converted model
interpreter = tf.lite.Interpreter(model_content=tflite_model)
signatures = interpreter.get_signature_list()
print(signatures)
シグネチャの実行
TensorFlow の推論 API は、シグネチャに基づく実行をサポートします。
- シグネチャで指定された入出力の名前を使用して、入出力テンソルにアクセスします。
- シグネチャキーで指定されたグラフの各入力点を個別に実行します。
- SavedModel の初期化手順をサポートします。
Java、C++、Python 言語バインディングは現在使用できます。次のセクションの例を参照してください。
Java
try (Interpreter interpreter = new Interpreter(file_of_tensorflowlite_model)) {
// Run encoding signature.
Map<String, Object> inputs = new HashMap<>();
inputs.put("x", input);
Map<String, Object> outputs = new HashMap<>();
outputs.put("encoded_result", encoded_result);
interpreter.runSignature(inputs, outputs, "encode");
// Run decoding signature.
Map<String, Object> inputs = new HashMap<>();
inputs.put("x", encoded_result);
Map<String, Object> outputs = new HashMap<>();
outputs.put("decoded_result", decoded_result);
interpreter.runSignature(inputs, outputs, "decode");
}
C++
SignatureRunner* encode_runner =
interpreter->GetSignatureRunner("encode");
encode_runner->ResizeInputTensor("x", {100});
encode_runner->AllocateTensors();
TfLiteTensor* input_tensor = encode_runner->input_tensor("x");
float* input = GetTensorData<float>(input_tensor);
// Fill `input`.
encode_runner->Invoke();
const TfLiteTensor* output_tensor = encode_runner->output_tensor(
"encoded_result");
float* output = GetTensorData<float>(output_tensor);
// Access `output`.
Python
# Load the TFLite model in TFLite Interpreter
interpreter = tf.lite.Interpreter(model_content=tflite_model)
# Print the signatures from the converted model
signatures = interpreter.get_signature_list()
print('Signature:', signatures)
# encode and decode are callable with input as arguments.
encode = interpreter.get_signature_runner('encode')
decode = interpreter.get_signature_runner('decode')
# 'encoded' and 'decoded' are dictionaries with all outputs from the inference.
input = tf.constant([1, 2, 3], dtype=tf.float32)
print('Input:', input)
encoded = encode(x=input)
print('Encoded result:', encoded)
decoded = decode(x=encoded['encoded_result'])
print('Decoded result:', decoded)
既知の制限
- TFLite インタープリタはスレッドの安全を保証しないため、同じインタープリタからのシグネチャランナーは同時に実行されません。
- C/iOS/Swift のサポートはまだ提供されていません。
更新
- バージョン 2.7
- 複数のシグネチャ機能が実装されました。
- バージョン 2 以降のすべてのコンバータ API は、シグネチャ対応 TensorFlow Lite モデルを生成します。
- バージョン 2.5
- シグネチャ機能は、
from_saved_model
コンバータ API から利用できます。
- シグネチャ機能は、