Introdução
Este guia demonstra como o Tensorflow Extended (TFX) pode criar e avaliar modelos de machine learning que serão implantados no dispositivo. O TFX agora oferece suporte nativo ao TFLite , o que possibilita realizar inferências altamente eficientes em dispositivos móveis.
Este guia orienta você nas alterações que podem ser feitas em qualquer pipeline para gerar e avaliar modelos TFLite. Fornecemos um exemplo completo aqui , demonstrando como o TFX pode treinar e avaliar modelos TFLite que são treinados no conjunto de dados MNIST . Além disso, mostramos como o mesmo pipeline pode ser usado para exportar simultaneamente tanto o SavedModel padrão baseado em Keras quanto o TFLite, permitindo aos usuários comparar a qualidade dos dois.
Presumimos que você esteja familiarizado com o TFX, nossos componentes e pipelines. Caso contrário, consulte este tutorial .
Passos
Apenas duas etapas são necessárias para criar e avaliar um modelo TFLite no TFX. A primeira etapa é invocar o reescritor TFLite no contexto do TFX Trainer para converter o modelo TensorFlow treinado em um modelo TFLite. A segunda etapa é configurar o Avaliador para avaliar modelos TFLite. Agora discutiremos cada um deles.
Invocando o reescritor TFLite dentro do Trainer.
O TFX Trainer espera que um run_fn
definido pelo usuário seja especificado em um arquivo de módulo. Este run_fn
define o modelo a ser treinado, treina-o para o número especificado de iterações e exporta o modelo treinado.
No restante desta seção, fornecemos trechos de código que mostram as alterações necessárias para invocar o reescritor TFLite e exportar um modelo TFLite. Todo esse código está localizado no run_fn
do módulo MNIST TFLite .
Conforme mostrado no código abaixo, devemos primeiro criar uma assinatura que receba um Tensor
para cada recurso como entrada. Observe que isso é diferente da maioria dos modelos existentes no TFX, que usam protos tf.Example serializados como entrada.
signatures = {
'serving_default':
_get_serve_tf_examples_fn(
model, tf_transform_output).get_concrete_function(
tf.TensorSpec(
shape=[None, 784],
dtype=tf.float32,
name='image_floats'))
}
Em seguida, o modelo Keras é salvo como SavedModel da mesma forma que normalmente.
temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp')
model.save(temp_saving_model_dir, save_format='tf', signatures=signatures)
Finalmente, criamos uma instância do reescritor TFLite ( tfrw
) e invocamos-o no SavedModel para obter o modelo TFLite. Armazenamos este modelo TFLite no serving_model_dir
fornecido pelo chamador do run_fn
. Dessa forma, o modelo TFLite é armazenado no local onde todos os componentes downstream do TFX esperam encontrar o modelo.
tfrw = rewriter_factory.create_rewriter(
rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter')
converters.rewrite_saved_model(temp_saving_model_dir,
fn_args.serving_model_dir,
tfrw,
rewriter.ModelType.TFLITE_MODEL)
Avaliando o modelo TFLite.
O TFX Evaluator oferece a capacidade de analisar modelos treinados para compreender sua qualidade em uma ampla variedade de métricas. Além de analisar SavedModels, o TFX Evaluator agora também é capaz de analisar modelos TFLite.
O trecho de código a seguir (reproduzido do pipeline MNIST ) mostra como configurar um avaliador que analisa um modelo TFLite.
# Informs the evaluator that the model is a TFLite model.
eval_config_lite.model_specs[0].model_type = 'tf_lite'
...
# Uses TFMA to compute the evaluation statistics over features of a TFLite
# model.
model_analyzer_lite = Evaluator(
examples=example_gen.outputs['examples'],
model=trainer_lite.outputs['model'],
eval_config=eval_config_lite,
).with_id('mnist_lite')
Conforme mostrado acima, a única alteração que precisamos fazer é definir o campo model_type
como tf_lite
. Nenhuma outra alteração de configuração é necessária para analisar o modelo TFLite. Independentemente de ser analisado um modelo TFLite ou um SavedModel, a saída do Evaluator
terá exatamente a mesma estrutura.
No entanto, observe que o Avaliador assume que o modelo TFLite é salvo em um arquivo chamado tflite
dentro de trainer_lite.outputs['model'].