tf.feature_column
を Keras 前処理レイヤーに移行する" />
TensorFlow.org で表示 | Google Colab で実行 | GitHub でソースを表示 | ノートブックをダウンロード |
通常、モデルのトレーニングには、特に構造化データを扱う場合に、特徴量の前処理が必要となることがあります。TensorFlow 1 で tf.estimator.Estimator
をトレーニングする場合、通常、tf.feature_column
API を使用して特徴量の前処理を実行します。TensorFlow 2 では、Keras 前処理レイヤーで直接実行できます。
この移行ガイドでは、特徴量カラムと前処理レイヤーの両方を使用した一般的な特徴量変換を紹介し、両方の API を使用して完全なモデルをトレーニングします。
まず、必要なものをインポートします。
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import math
2024-01-11 18:05:46.202664: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-01-11 18:05:46.202711: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-01-11 18:05:46.204283: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
次に、デモのために特徴量カラムを呼び出すためのユーティリティ関数を追加します。
def call_feature_columns(feature_columns, inputs):
# This is a convenient way to call a `feature_column` outside of an estimator
# to display its output.
feature_layer = tf1.keras.layers.DenseFeatures(feature_columns)
return feature_layer(inputs)
入力処理
Estimator で特徴量カラムを使用するには、モデル入力は常にテンソルのディクショナリであることが期待されます。
input_dict = {
'foo': tf.constant([1]),
'bar': tf.constant([0]),
'baz': tf.constant([-1])
}
各特徴量カラムは、ソースデータにインデックスを付けるためのキーを使用して作成する必要があります。すべての特徴量カラムの出力は連結され、Estimator モデルによって使用されます。
columns = [
tf1.feature_column.numeric_column('foo'),
tf1.feature_column.numeric_column('bar'),
tf1.feature_column.numeric_column('baz'),
]
call_feature_columns(columns, input_dict)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_36989/3124623333.py:2: numeric_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version. Instructions for updating: Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model. <tf.Tensor: shape=(1, 3), dtype=float32, numpy=array([[ 0., -1., 1.]], dtype=float32)>
Keras では、モデル入力はより柔軟です。tf.keras.Model
は、単一のテンソル入力、テンソル特徴量のリスト、またはテンソル特徴量のディクショナリを処理できます。モデルの作成時に tf.keras.Input
のディクショナリを渡すことで、ディクショナリの入力を処理できます。入力は自動的に連結されないため、より柔軟な方法で使用できます。これらは tf.keras.layers.Concatenate
で連結できます。
inputs = {
'foo': tf.keras.Input(shape=()),
'bar': tf.keras.Input(shape=()),
'baz': tf.keras.Input(shape=()),
}
# Inputs are typically transformed by preprocessing layers before concatenation.
outputs = tf.keras.layers.Concatenate()(inputs.values())
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model(input_dict)
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 1., 0., -1.], dtype=float32)>
One-hot エンコーディングの整数 ID
一般的に、既知の範囲の整数入力を One-hot エンコードすることにより特徴量を変換できます。特徴量カラムを使用した例を次に示します。
categorical_col = tf1.feature_column.categorical_column_with_identity(
'type', num_buckets=3)
indicator_col = tf1.feature_column.indicator_column(categorical_col)
call_feature_columns(indicator_col, {'type': [0, 1, 2]})
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_36989/1369923821.py:1: categorical_column_with_identity (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version. Instructions for updating: Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model. WARNING:tensorflow:From /tmpfs/tmp/ipykernel_36989/1369923821.py:3: indicator_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version. Instructions for updating: Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model. <tf.Tensor: shape=(3, 3), dtype=float32, numpy= array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)>
Keras 前処理レイヤーを使用すると、これらのカラムを output_mode
を 'one_hot'
に設定した単一の tf.keras.layers.CategoryEncoding
レイヤーに置き換えることができます。
one_hot_layer = tf.keras.layers.CategoryEncoding(
num_tokens=3, output_mode='one_hot')
one_hot_layer([0, 1, 2])
<tf.Tensor: shape=(3, 3), dtype=float32, numpy= array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)>
注意: 大規模な One-hot エンコーディングの場合、出力のスパース表現を使用する方がはるかに効率的です。sparse=True
を CategoryEncoding
レイヤーに渡すと、レイヤーの出力は tf.sparse.SparseTensor
になり、効率的に tf.keras.layers.Dense
レイヤーへの入力として処理されます。
数値的特徴量の正規化
特徴量カラムを持つ連続浮動小数点特徴量を処理する場合、tf.feature_column.numeric_column
を使用する必要があります。入力が既に正規化されている場合、これを Keras に変換するのは簡単です。上記のように、tf.keras.Input
をモデルに直接使用するだけです。
numeric_column
を使用して入力を正規化することもできます。
def normalize(x):
mean, variance = (2.0, 1.0)
return (x - mean) / math.sqrt(variance)
numeric_col = tf1.feature_column.numeric_column('col', normalizer_fn=normalize)
call_feature_columns(numeric_col, {'col': tf.constant([[0.], [1.], [2.]])})
<tf.Tensor: shape=(3, 1), dtype=float32, numpy= array([[-2.], [-1.], [ 0.]], dtype=float32)>
対照的に、Keras では、この正規化は tf.keras.layers.Normalization
で実行できます。
normalization_layer = tf.keras.layers.Normalization(mean=2.0, variance=1.0)
normalization_layer(tf.constant([[0.], [1.], [2.]]))
<tf.Tensor: shape=(3, 1), dtype=float32, numpy= array([[-2.], [-1.], [ 0.]], dtype=float32)>
数値特徴量のバケット化と One-hot エンコーディング
連続する浮動小数点の入力を変換するもう 1 つの一般的な方法は、固定範囲の整数にバケット化することです。
特徴量カラムでは、tf.feature_column.bucketized_column
を使用します。
numeric_col = tf1.feature_column.numeric_column('col')
bucketized_col = tf1.feature_column.bucketized_column(numeric_col, [1, 4, 5])
call_feature_columns(bucketized_col, {'col': tf.constant([1., 2., 3., 4., 5.])})
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_36989/3043215186.py:2: bucketized_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version. Instructions for updating: Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model. <tf.Tensor: shape=(5, 4), dtype=float32, numpy= array([[0., 1., 0., 0.], [0., 1., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]], dtype=float32)>
Keras では、これを tf.keras.layers.Discretization
に置き換えます。
discretization_layer = tf.keras.layers.Discretization(bin_boundaries=[1, 4, 5])
one_hot_layer = tf.keras.layers.CategoryEncoding(
num_tokens=4, output_mode='one_hot')
one_hot_layer(discretization_layer([1., 2., 3., 4., 5.]))
<tf.Tensor: shape=(5, 4), dtype=float32, numpy= array([[0., 1., 0., 0.], [0., 1., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.], [0., 0., 0., 1.]], dtype=float32)>
語彙を使用した文字列データの One-hot エンコーディング
文字列の特徴量を処理するには、多くの場合、文字列をインデックスに変換するために語彙の検索が必要です。特徴量カラムを使用して文字列を検索し、インデックスを One-hot エンコードする例を次に示します。
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
'sizes',
vocabulary_list=['small', 'medium', 'large'],
num_oov_buckets=0)
indicator_col = tf1.feature_column.indicator_column(vocab_col)
call_feature_columns(indicator_col, {'sizes': ['small', 'medium', 'large']})
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_36989/2845961037.py:1: categorical_column_with_vocabulary_list (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version. Instructions for updating: Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model. <tf.Tensor: shape=(3, 3), dtype=float32, numpy= array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)>
Keras 前処理レイヤーを使用して、output_mode
を 'one_hot'
に設定して tf.keras.layers.StringLookup
レイヤーを使用します。
string_lookup_layer = tf.keras.layers.StringLookup(
vocabulary=['small', 'medium', 'large'],
num_oov_indices=0,
output_mode='one_hot')
string_lookup_layer(['small', 'medium', 'large'])
<tf.Tensor: shape=(3, 3), dtype=float32, numpy= array([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]], dtype=float32)>
注意: 大規模な One-hot エンコーディングの場合、出力のスパース表現を使用する方がはるかに効率的です。sparse=True
を StringLookup
レイヤーに渡すと、レイヤーの出力は tf.sparse.SparseTensor
になり、効率的に tf.keras.layers.Dense
レイヤーへの入力として処理されます。
語彙を使用した文字列データの埋め込み
より大きな語彙の場合、パフォーマンスを向上させるために埋め込みが必要になることがよくあります。特徴量カラムを使用して文字列特徴量を埋め込む例を次に示します。
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
'col',
vocabulary_list=['small', 'medium', 'large'],
num_oov_buckets=0)
embedding_col = tf1.feature_column.embedding_column(vocab_col, 4)
call_feature_columns(embedding_col, {'col': ['small', 'medium', 'large']})
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_36989/999372599.py:5: embedding_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version. Instructions for updating: Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model. <tf.Tensor: shape=(3, 4), dtype=float32, numpy= array([[-0.22725259, 0.00829591, 0.08101071, 0.94072044], [-0.2397708 , 0.47947466, 0.4402268 , 0.5795534 ], [-0.9305114 , -0.02075406, -0.60498226, -0.10655837]], dtype=float32)>
これは、Keras 前処理レイヤーを使用して、tf.keras.layers.StringLookup
レイヤーと tf.keras.layers.Embedding
レイヤーを組み合わせることで実現できます。StringLookup
のデフォルトの出力は、埋め込みに直接入力できる整数インデックスになります。
注意: Embedding
レイヤーには、トレーニング可能なパラメータが含まれています。StringLookup
レイヤーはモデルの内部または外部のデータに適用できますが、正しくトレーニングするには、Embedding
が常にトレーニング可能な Keras モデルの一部である必要があります。
string_lookup_layer = tf.keras.layers.StringLookup(
vocabulary=['small', 'medium', 'large'], num_oov_indices=0)
embedding = tf.keras.layers.Embedding(3, 4)
embedding(string_lookup_layer(['small', 'medium', 'large']))
<tf.Tensor: shape=(3, 4), dtype=float32, numpy= array([[ 0.0173264 , -0.00587343, -0.00816015, 0.02549114], [ 0.01286792, -0.00867562, -0.04363815, 0.03605429], [ 0.04916598, 0.04651806, 0.03595657, -0.00171267]], dtype=float32)>
重み付きカテゴリカルデータの和
場合によっては、重みが関連付けられているカテゴリが出現するたびにカテゴリカルデータを処理する必要があります。特徴量カラムでは、これは tf.feature_column.weighted_categorical_column
で処理されます。indicator_column
と組み合わせると、カテゴリごとの重みの和を計算できます。
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
categorical_col = tf1.feature_column.categorical_column_with_identity(
'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
categorical_col, 'weights')
indicator_col = tf1.feature_column.indicator_column(weighted_categorical_col)
call_feature_columns(indicator_col, {'ids': ids, 'weights': weights})
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_36989/3529191023.py:6: weighted_categorical_column (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version. Instructions for updating: Use Keras preprocessing layers instead, either directly or via the `tf.keras.utils.FeatureSpace` utility. Each of `tf.feature_column.*` has a functional equivalent in `tf.keras.layers` for feature preprocessing when training a Keras model. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/feature_column/feature_column_v2.py:4033: sparse_merge (from tensorflow.python.ops.sparse_ops) is deprecated and will be removed in a future version. Instructions for updating: No similar op available at this time. <tf.Tensor: shape=(1, 20), dtype=float32, numpy= array([[0. , 0. , 0. , 0. , 0. , 1.2, 0. , 0. , 0. , 0. , 0. , 1.5, 0. , 0. , 0. , 0. , 0. , 2. , 0. , 0. ]], dtype=float32)>
Keras では、これは output_mode='count'
で count_weights
入力を tf.keras.layers.CategoryEncoding
に渡すことで実行できます。
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
# Using sparse output is more efficient when `num_tokens` is large.
count_layer = tf.keras.layers.CategoryEncoding(
num_tokens=20, output_mode='count', sparse=True)
tf.sparse.to_dense(count_layer(ids, count_weights=weights))
<tf.Tensor: shape=(1, 20), dtype=float32, numpy= array([[0. , 0. , 0. , 0. , 0. , 1.2, 0. , 0. , 0. , 0. , 0. , 1.5, 0. , 0. , 0. , 0. , 0. , 2. , 0. , 0. ]], dtype=float32)>
重み付きカテゴリカルデータの埋め込み
または、重み付きカテゴリカル入力を埋め込みたい場合もあります。特徴量カラムでは、embedding_column
に combiner
引数が含まれています。サンプルにカテゴリの複数のエントリが含まれている場合、それらは引数の設定(デフォルトでは 'mean'
)に従って結合されます。
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
categorical_col = tf1.feature_column.categorical_column_with_identity(
'ids', num_buckets=20)
weighted_categorical_col = tf1.feature_column.weighted_categorical_column(
categorical_col, 'weights')
embedding_col = tf1.feature_column.embedding_column(
weighted_categorical_col, 4, combiner='mean')
call_feature_columns(embedding_col, {'ids': ids, 'weights': weights})
<tf.Tensor: shape=(1, 4), dtype=float32, numpy= array([[ 0.01570974, -0.12220951, -0.4174866 , -0.16703086]], dtype=float32)>
Keras では、tf.keras.layers.Embedding
に対する combiner
オプションはありませんが、tf.keras.layers.Dense
で同じ効果を実現できます。上記の embedding_column
は、カテゴリの重みに従って埋め込みベクトルを単純に線形結合しています。一見明らかではありませんが、カテゴリカル入力をサイズ (num_tokens)
の疎な重みベクトルとして表し、形状 (embedding_size, num_tokens)
の Dense
カーネルを掛けるのとまったく同じです。
ids = tf.constant([[5, 11, 5, 17, 17]])
weights = tf.constant([[0.5, 1.5, 0.7, 1.8, 0.2]])
# For `combiner='mean'`, normalize your weights to sum to 1. Removing this line
# would be equivalent to an `embedding_column` with `combiner='sum'`.
weights = weights / tf.reduce_sum(weights, axis=-1, keepdims=True)
count_layer = tf.keras.layers.CategoryEncoding(
num_tokens=20, output_mode='count', sparse=True)
embedding_layer = tf.keras.layers.Dense(4, use_bias=False)
embedding_layer(count_layer(ids, count_weights=weights))
<tf.Tensor: shape=(1, 4), dtype=float32, numpy= array([[-0.20355289, 0.20143004, 0.03062132, 0.2749697 ]], dtype=float32)>
完全なトレーニングサンプル
完全なトレーニングワークフローでは、まず、異なる型の 3 つの特徴量を含むいくつかのデータを準備します。
features = {
'type': [0, 1, 1],
'size': ['small', 'small', 'medium'],
'weight': [2.7, 1.8, 1.6],
}
labels = [1, 1, 0]
predict_features = {'type': [0], 'size': ['foo'], 'weight': [-0.7]}
TensorFlow 1 と TensorFlow 2 の両方のワークフローに共通する定数をいくつか定義します。
vocab = ['small', 'medium', 'large']
one_hot_dims = 3
embedding_dims = 4
weight_mean = 2.0
weight_variance = 1.0
特徴量カラムを使用する
特徴量カラムは、作成時に Estimator にリストとして渡す必要があり、トレーニング中に暗黙的に呼び出されます。
categorical_col = tf1.feature_column.categorical_column_with_identity(
'type', num_buckets=one_hot_dims)
# Convert index to one-hot; e.g. [2] -> [0,0,1].
indicator_col = tf1.feature_column.indicator_column(categorical_col)
# Convert strings to indices; e.g. ['small'] -> [1].
vocab_col = tf1.feature_column.categorical_column_with_vocabulary_list(
'size', vocabulary_list=vocab, num_oov_buckets=1)
# Embed the indices.
embedding_col = tf1.feature_column.embedding_column(vocab_col, embedding_dims)
normalizer_fn = lambda x: (x - weight_mean) / math.sqrt(weight_variance)
# Normalize the numeric inputs; e.g. [2.0] -> [0.0].
numeric_col = tf1.feature_column.numeric_column(
'weight', normalizer_fn=normalizer_fn)
estimator = tf1.estimator.DNNClassifier(
feature_columns=[indicator_col, embedding_col, numeric_col],
hidden_units=[1])
def _input_fn():
return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)
estimator.train(_input_fn)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_36989/1892339471.py:17: DNNClassifier.__init__ (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:807: Estimator.__init__ (from tensorflow_estimator.python.estimator.estimator) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1844: RunConfig.__init__ (from tensorflow_estimator.python.estimator.run_config) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Using default config. WARNING:tensorflow:Using temporary folder as model directory: /tmpfs/tmp/tmpf3rx3u6r INFO:tensorflow:Using config: {'_model_dir': '/tmpfs/tmp/tmpf3rx3u6r', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true graph_options { rewrite_options { meta_optimizer_iterations: ONE } } , '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_checkpoint_save_graph_def': True, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1} INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/dnn.py:446: dnn_logit_fn_builder (from tensorflow_estimator.python.estimator.canned.dnn) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/adagrad.py:138: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/model_fn.py:250: EstimatorSpec.__new__ (from tensorflow_estimator.python.estimator.model_fn) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Done calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1416: NanTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1419: LoggingTensorHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/basic_session_run_hooks.py:232: SecondOrStepTimer.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/estimator.py:1456: CheckpointSaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Create CheckpointSaverHook. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:579: StepCounterHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:586: SummarySaverHook.__init__ (from tensorflow.python.training.basic_session_run_hooks) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. 2024-01-11 18:05:53.152605: W tensorflow/core/common_runtime/type_inference.cc:339] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT64 } } } is neither a subtype nor a supertype of the combined inputs preceding it: type_id: TFT_OPTIONAL args { type_id: TFT_PRODUCT args { type_id: TFT_TENSOR args { type_id: TFT_INT32 } } } for Tuple type infernce function 0 while inferring type of node 'dnn/zero_fraction/cond/output/_18' INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0... INFO:tensorflow:Saving checkpoints for 0 into /tmpfs/tmp/tmpf3rx3u6r/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0... WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1455: SessionRunArgs.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1454: SessionRunContext.__init__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/training/monitored_session.py:1474: SessionRunValues.__new__ (from tensorflow.python.training.session_run_hook) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:loss = 1.0011392, step = 0 INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 3... INFO:tensorflow:Saving checkpoints for 3 into /tmpfs/tmp/tmpf3rx3u6r/model.ckpt. INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 3... INFO:tensorflow:Loss for final step: 0.73061395. <tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7f0bb025c0d0>
また、特徴量カラムは、モデルで推論を実行するときに入力データを変換するためにも使用されます。
def _predict_fn():
return tf1.data.Dataset.from_tensor_slices(predict_features).batch(1)
next(estimator.predict(_predict_fn))
INFO:tensorflow:Calling model_fn. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/head.py:596: ClassificationOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/head.py:1307: RegressionOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow_estimator/python/estimator/canned/head.py:1309: PredictOutput.__init__ (from tensorflow.python.saved_model.model_utils.export_output) is deprecated and will be removed in a future version. Instructions for updating: Use tf.keras instead. INFO:tensorflow:Done calling model_fn. INFO:tensorflow:Graph was finalized. INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpf3rx3u6r/model.ckpt-3 INFO:tensorflow:Running local_init_op. INFO:tensorflow:Done running local_init_op. {'logits': array([-2.0540094], dtype=float32), 'logistic': array([0.11364788], dtype=float32), 'probabilities': array([0.88635206, 0.11364787], dtype=float32), 'class_ids': array([0]), 'classes': array([b'0'], dtype=object), 'all_class_ids': array([0, 1], dtype=int32), 'all_classes': array([b'0', b'1'], dtype=object)}
Keras 前処理レイヤーを使用する
Keras の前処理レイヤーは、より柔軟に呼び出せます。レイヤーはテンソルに直接適用したり、tf.data
入力パイプライン内で使用したり、トレーニング可能な Keras モデルに直接構築したりできます。
この例では、tf.data
入力パイプライン内に前処理レイヤーを適用します。これを行うには、別の tf.keras.Model
を定義して、入力する特徴量を前処理します。このモデルはトレーニング可能ではありませんが、前処理レイヤーをグループ化する便利な方法です。
inputs = {
'type': tf.keras.Input(shape=(), dtype='int64'),
'size': tf.keras.Input(shape=(), dtype='string'),
'weight': tf.keras.Input(shape=(), dtype='float32'),
}
# Convert index to one-hot; e.g. [2] -> [0,0,1].
type_output = tf.keras.layers.CategoryEncoding(
one_hot_dims, output_mode='one_hot')(inputs['type'])
# Convert size strings to indices; e.g. ['small'] -> [1].
size_output = tf.keras.layers.StringLookup(vocabulary=vocab)(inputs['size'])
# Normalize the numeric inputs; e.g. [2.0] -> [0.0].
weight_output = tf.keras.layers.Normalization(
axis=None, mean=weight_mean, variance=weight_variance)(inputs['weight'])
outputs = {
'type': type_output,
'size': size_output,
'weight': weight_output,
}
preprocessing_model = tf.keras.Model(inputs, outputs)
注意: レイヤー作成時に語彙と正規化統計を提供する代わりに、多くの前処理レイヤーは、入力データからレイヤーの状態を直接学習するための adapt()
メソッドを提供します。詳細については、前処理ガイドを参照してください。
tf.data.Dataset.map
への呼び出し内でこのモデルを適用できるようになりました。map
に渡される関数は自動的に tf.function
に変換され、tf.function
コードを記述する際の通常の注意事項が適用されることに注意してください(副作用はありません)。
# Apply the preprocessing in tf.data.Dataset.map.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
dataset = dataset.map(lambda x, y: (preprocessing_model(x), y),
num_parallel_calls=tf.data.AUTOTUNE)
# Display a preprocessed input sample.
next(dataset.take(1).as_numpy_iterator())
({'type': array([[1., 0., 0.]], dtype=float32), 'size': array([1]), 'weight': array([0.70000005], dtype=float32)}, array([1], dtype=int32))
次に、トレーニング可能なレイヤーを含む別の Model
を定義します。このモデルへの入力が、前処理された特徴量の型と形状をどのように反映しているかに注目してください。
inputs = {
'type': tf.keras.Input(shape=(one_hot_dims,), dtype='float32'),
'size': tf.keras.Input(shape=(), dtype='int64'),
'weight': tf.keras.Input(shape=(), dtype='float32'),
}
# Since the embedding is trainable, it needs to be part of the training model.
embedding = tf.keras.layers.Embedding(len(vocab), embedding_dims)
outputs = tf.keras.layers.Concatenate()([
inputs['type'],
embedding(inputs['size']),
tf.expand_dims(inputs['weight'], -1),
])
outputs = tf.keras.layers.Dense(1)(outputs)
training_model = tf.keras.Model(inputs, outputs)
training_model
を tf.keras.Model.fit
でトレーニングできるようになりました。
# Train on the preprocessed data.
training_model.compile(
loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
training_model.fit(dataset)
3/3 [==============================] - 1s 5ms/step - loss: 0.7808 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1704996357.417202 37160 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. <keras.src.callbacks.History at 0x7f0b682012e0>
最後に、推論時に、これらの個別の段階を組み合わせて、生の特徴量入力を処理する単一のモデルにすると便利です。
inputs = preprocessing_model.input
outputs = training_model(preprocessing_model(inputs))
inference_model = tf.keras.Model(inputs, outputs)
predict_dataset = tf.data.Dataset.from_tensor_slices(predict_features).batch(1)
inference_model.predict(predict_dataset)
1/1 [==============================] - 0s 103ms/step array([[1.0852278]], dtype=float32)
この合成モデルは、後で使用するために .keras
ファイルとして保存できます。
inference_model.save('model.keras')
restored_model = tf.keras.models.load_model('model.keras')
restored_model.predict(predict_dataset)
1/1 [==============================] - 0s 80ms/step array([[1.0852278]], dtype=float32)
注意: 前処理レイヤーはトレーニングできないため、tf.data
を使用して非同期で適用できます。これには、前処理されたバッチをプリフェッチし、アクセラレータを解放してモデルの微分可能な部分に集中できるため、パフォーマンス上の利点があります(詳細については、tf.data
API によるパフォーマンスの向上ガイドのプリフェッチセクションを参照してください)。このガイドが示すように、トレーニング中に前処理を分離し、推論中にそれを構成することは、これらのパフォーマンスの向上を活用する柔軟な方法です。ただし、モデルが小さい場合や前処理時間を無視できる場合は、最初から完全なモデルに前処理を組み込む方が簡単な場合があります。これを行うには、tf.keras.Input
で始まる単一のモデルを構築し、その後に前処理レイヤー、その後にトレーニング可能なレイヤーを構築します。
特徴量カラムに対応する Keras レイヤー
参考までに、特徴量カラムにほぼ対応する Keras 前処理レイヤーを次に示します。
† tf.keras.layers.TextVectorization
は、自由形式のテキスト入力 (文全体または段落全体など)を直接処理できます。これは、TensorFlow 1 でのカテゴリカルシーケンス処理の 1 対 1 の置き換えではありませんが、アドホックテキスト前処理の便利な置き換えを提供します。
† tf.keras.layers.TextVectorization
は、自由形式のテキスト入力 (文全体または段落全体など)を直接処理できます。これは、TensorFlow 1 でのカテゴリカルシーケンス処理の 1 対 1 の置き換えではありませんが、アドホックテキスト前処理の便利な置き換えを提供します。
注意: tf.estimator.LinearClassifier
などの線形 Estimator は、embedding_column
または indicator_column
なしで直接のカテゴリカル入力(整数インデックス)を処理できます。ただし、整数インデックスを tf.keras.layers.Dense
または tf.keras.experimental.LinearModel
に直接渡すことはできません。これらの入力は、 Dense
または LinearModel
を呼び出す前に最初に tf.layers.CategoryEncoding
で output_mode='count'
(カテゴリサイズが大きい場合は sparse=True
)でエンコードする必要があります)。
次のステップ
- Keras 前処理レイヤーの詳細については、前処理レイヤーの操作ガイドを参照してください。
- 前処理レイヤーを構造化データに適用する詳細な例については、Keras 前処理レイヤーを使用して構造化データを分類するチュートリアルを参照してください。