TensorFlow.org で表示 | Google Colab で実行 | GitHub でソースを表示 | ノートブックをダウンロード |
概要
このチュートリアルでは、トレーニング中またはトレーニング後に tf.distribute.Strategy
を使用して SavedModel 形式でモデルを保存して読み込む方法を説明します。Keras モデルの保存と読み込みには、高レベル(tf.keras.Model.save
と tf.keras.models.load_model
)と低レベル(tf.saved_model.save
と tf.saved_model.load
)の 2 種類の API があります。
SavedModel とシリアル化の全般的な内容については、SavedModel ガイドと Keras モデルのシリアル化ガイドをお読みください。では、単純な例から始めましょう。
注意: TensorFlow モデルはコードであるため、信頼できないコードには注意する必要があります。詳細は、TensorFlow を安全に使用するをご覧ください。
依存関係をインポートします。
import tensorflow_datasets as tfds
import tensorflow as tf
2024-01-11 18:14:24.741176: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2024-01-11 18:14:24.741223: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2024-01-11 18:14:24.742716: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
TensorFlow Datasets と tf.data
でデータを読み込んで準備し、tf.distribute.MirroredStrategy
を使ってモデルを作成します。
mirrored_strategy = tf.distribute.MirroredStrategy()
def get_data():
datasets = tfds.load(name='mnist', as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
return train_dataset, eval_dataset
def get_model():
with mirrored_strategy.scope():
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
tf.keras.Model.fit
を使用してモデルをトレーニングします。
model = get_model()
train_dataset, eval_dataset = get_data()
model.fit(train_dataset, epochs=2)
2024-01-11 18:14:31.467356: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. Epoch 1/2 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1704996877.849655 50970 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 235/235 [==============================] - ETA: 0s - loss: 0.3291 - sparse_categorical_accuracy: 0.9094INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',). 235/235 [==============================] - 8s 8ms/step - loss: 0.3291 - sparse_categorical_accuracy: 0.9094 Epoch 2/2 235/235 [==============================] - 2s 7ms/step - loss: 0.1068 - sparse_categorical_accuracy: 0.9685 <keras.src.callbacks.History at 0x7f7e10386cd0>
モデルを保存して読み込む
作業に使用する単純なモデルを準備できたので、保存と読み込みに使用する API を見てみましょう。使用できる API には、以下の 2 種類があります。
- 高レベル(Keras):
Model.save
およびtf.keras.models.load_model
(.keras
zip アーカイブ形式) - 低レベル:
tf.saved_model.save
およびtf.saved_model.load
(TF SavedModel 形式)
Keras API
Keras API を使用したモデルの保存と読み込みの例を以下に示します。
keras_model_path = '/tmp/keras_save.keras'
model.save(keras_model_path)
tf.distribute.Strategy
を使用せずにモデルを復元します。
restored_keras_model = tf.keras.models.load_model(keras_model_path)
restored_keras_model.fit(train_dataset, epochs=2)
Epoch 1/2 235/235 [==============================] - 2s 4ms/step - loss: 0.0686 - sparse_categorical_accuracy: 0.9798 Epoch 2/2 235/235 [==============================] - 1s 4ms/step - loss: 0.0540 - sparse_categorical_accuracy: 0.9844 <keras.src.callbacks.History at 0x7f7f5048b670>
モデルを復元したら、Model.compile
をもう一度呼び出さずにそのままトレーニングを続行できます。これは、保存前にすでにコンパイル済みであるためです。このモデルは、Keras zip アーカイブ形式で保存されており、.keras
拡張子で識別できます。詳細については、Keras の保存に関するガイドをご覧ください。
次に、tf.distribute.Strategy
を使用してモデルを復元し、トレーニングします。
another_strategy = tf.distribute.OneDeviceStrategy('/cpu:0')
with another_strategy.scope():
restored_keras_model_ds = tf.keras.models.load_model(keras_model_path)
restored_keras_model_ds.fit(train_dataset, epochs=2)
Epoch 1/2 2024-01-11 18:14:45.633878: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. 2024-01-11 18:14:45.694588: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations. 15/235 [>.............................] - ETA: 2s - loss: 0.0826 - sparse_categorical_accuracy: 0.9776 2024-01-11 18:14:46.266915: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream. 2024-01-11 18:14:46.267137: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream. 2024-01-11 18:14:46.301591: E external/local_xla/xla/stream_executor/stream_executor_internal.h:177] SetPriority unimplemented for this stream. 235/235 [==============================] - 3s 11ms/step - loss: 0.0699 - sparse_categorical_accuracy: 0.9796 Epoch 2/2 235/235 [==============================] - 3s 11ms/step - loss: 0.0531 - sparse_categorical_accuracy: 0.9844
Model.fit
出力からわかるように、tf.distribute.Strategy
を使って期待どおり読み込まれました。ここで使用されるストラテジーは、保存前と同じストラテジーである必要はありません。
tf.saved_model
API
より低レベルの API を使用したモデルの保存方法は、Keras API を使う方法に似ています。
model = get_model() # get a fresh model
saved_model_path = '/tmp/tf_save'
tf.saved_model.save(model, saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets
読み込みは、tf.saved_model.load
を使用して行えますが、これは低レベル API(したがって、より幅広いユースケースのある API)であるため、Keras モデルを返しません。代わりに、推論を行うために使用できる関数を含むオブジェクトを返します。以下に例を示します。
DEFAULT_FUNCTION_KEY = 'serving_default'
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
読み込まれたオブジェクトには、それぞれにキーが関連付けられた複数の関数が含まれている可能性があります。"serving_default"
キーは、保存された Keras モデルを使用した推論関数のデフォルトのキーです。この関数で推論するには、以下のようにします。
predict_dataset = eval_dataset.map(lambda image, label: image)
for batch in predict_dataset.take(1):
print(inference_func(batch))
{'dense_3': <tf.Tensor: shape=(256, 10), dtype=float32, numpy= array([[ 0.04390968, 0.30341768, 0.05374109, ..., -0.35343656, 0.03065785, -0.00975093], [-0.04910231, 0.16482985, 0.06436244, ..., -0.27770516, 0.02216907, 0.13293922], [-0.05661844, 0.2683993 , -0.06041192, ..., -0.26340052, 0.02152548, 0.10264045], ..., [-0.12805948, 0.11079367, -0.10359426, ..., -0.26105058, 0.0311166 , 0.02954188], [-0.11231118, 0.22162321, 0.04027553, ..., -0.34616578, 0.02095792, 0.01622906], [-0.07966347, 0.08217648, -0.14690818, ..., -0.21150741, 0.03090278, -0.12792973]], dtype=float32)>} 2024-01-11 18:14:52.675006: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
また、分散方法で読み込んで推論を実行することもできます。
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
loaded = tf.saved_model.load(saved_model_path)
inference_func = loaded.signatures[DEFAULT_FUNCTION_KEY]
dist_predict_dataset = another_strategy.experimental_distribute_dataset(
predict_dataset)
# Calling the function in a distributed manner
for batch in dist_predict_dataset:
result = another_strategy.run(inference_func, args=(batch,))
print(result)
break
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') 2024-01-11 18:14:52.889461: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. WARNING:tensorflow:Using MirroredStrategy eagerly has significant overhead currently. We will be working on improving this in the future, but for now please wrap `call_for_each_replica` or `experimental_run` or `run` inside a tf.function to get the best performance. {'dense_3': PerReplica:{ 0: <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[ 0.04390973, 0.30341768, 0.05374111, -0.02709487, -0.28792804, -0.19333729, 0.325674 , -0.3534366 , 0.03065785, -0.00975094], [-0.04910226, 0.1648299 , 0.06436247, 0.0941631 , -0.36874533, -0.20513675, 0.09101719, -0.27770516, 0.02216907, 0.1329391 ], [-0.05661836, 0.26839924, -0.0604119 , 0.04293879, -0.37735796, -0.01866844, 0.23681116, -0.26340055, 0.02152544, 0.10264044], [ 0.10640895, 0.1561212 , 0.06909597, -0.06987031, -0.1984469 , -0.04289627, 0.18389529, -0.18640813, 0.06818488, -0.11756891], [ 0.01074794, 0.21058783, -0.13376951, -0.07198893, -0.34633294, -0.0951823 , 0.09859037, -0.17102982, 0.00822654, 0.02078733], [ 0.02912218, 0.19898024, -0.2194208 , -0.09297523, -0.22816458, -0.14823863, 0.1251952 , -0.22406554, 0.07672149, -0.06627873], [-0.07373619, 0.1642075 , 0.07785851, 0.0902054 , -0.38338202, -0.16163939, 0.15576416, -0.2677853 , 0.03417471, 0.14622498], [-0.01103535, 0.13357913, -0.11604594, -0.1436121 , -0.11463816, -0.09246697, 0.02174798, -0.13517785, 0.11130458, -0.16181585], [-0.05913471, 0.25684866, -0.08185954, 0.11305049, -0.41555196, -0.11584461, 0.17271446, -0.28591448, 0.0171692 , -0.03395621], [ 0.08944317, 0.0613468 , -0.1547694 , -0.04008061, -0.12746565, -0.06679891, 0.08744669, -0.14509115, -0.00697993, -0.20970711], [-0.02094699, 0.1820338 , -0.00243831, -0.01445665, -0.30628368, -0.20580608, 0.17843154, -0.25234053, -0.02907382, -0.03199375], [ 0.01979519, 0.2740837 , -0.01698644, -0.18219724, -0.40358156, -0.21365486, 0.04891555, -0.29936785, 0.06018814, -0.07291923], [-0.10811754, 0.13060504, -0.03464153, 0.06798229, -0.43056038, 0.01354938, 0.21008897, -0.31439742, -0.08188079, 0.11052423], [ 0.05240241, 0.17336613, -0.17786968, -0.08816151, -0.34909615, -0.11435185, 0.07548714, -0.27006733, 0.09544405, -0.20011938], [ 0.15914436, 0.24435948, -0.04742448, -0.02628867, -0.30243245, -0.13071185, 0.19322398, -0.34721792, 0.08353946, -0.16963294], [ 0.00991537, 0.2729301 , -0.10826392, -0.18254003, -0.303356 , -0.2544545 , 0.09932231, -0.23833567, 0.09032176, 0.1128719 ], [ 0.12592334, 0.10697531, -0.03352933, -0.18785489, -0.08577707, -0.20832878, 0.13447234, -0.06923658, 0.01704483, -0.13321611], [-0.05993862, 0.05663422, -0.05601699, -0.11318138, -0.32292652, -0.13415524, 0.07078642, -0.24608482, -0.03674737, -0.17912035], [-0.08632107, 0.1418068 , -0.13200448, 0.01969666, -0.25682348, -0.14198145, 0.10951944, -0.27877283, 0.02858731, -0.0374111 ], [ 0.11897041, 0.2594269 , 0.03674663, -0.11091953, -0.30228472, -0.11023479, 0.1663589 , -0.376381 , 0.06430978, -0.16984001], [-0.00289295, 0.04047507, -0.11863903, -0.05876009, -0.23271626, -0.13283929, 0.12130734, -0.13938479, -0.07576559, 0.01863652], [ 0.01375921, -0.00918115, -0.09295098, -0.06612265, -0.159753 , -0.12676233, 0.0832831 , -0.20027536, -0.06356223, -0.1602694 ], [-0.11601318, 0.05592554, -0.00799035, -0.03840588, -0.31699634, -0.12557843, 0.00217704, -0.13452399, -0.04661082, -0.00160295], [-0.04387126, 0.02643408, -0.08738586, 0.0719263 , -0.27419144, -0.10074591, -0.00859936, -0.1349473 , -0.02253597, 0.09461489], [ 0.00437208, 0.09436843, -0.05663288, -0.00419031, -0.22780542, -0.1780376 , 0.1348795 , -0.1746904 , -0.1173916 , -0.06692152], [ 0.03692336, 0.20061986, 0.07223521, -0.00423313, -0.36737412, -0.11080196, 0.1832721 , -0.31301066, -0.0066702 , -0.13494146], [ 0.05483724, 0.03380407, -0.04579096, -0.0625519 , -0.19104283, -0.12388766, 0.04404955, -0.1911788 , -0.02904271, -0.12793231], [-0.03292175, 0.0789917 , -0.20172037, -0.05561272, -0.17042139, -0.29631954, 0.15214083, -0.10865977, -0.04539655, -0.09254031], [ 0.05429688, 0.24875508, -0.25274026, -0.08207695, -0.250165 , -0.07061552, 0.03925603, -0.1252473 , 0.11538928, 0.02362049], [-0.01455817, 0.08509667, -0.02281824, 0.09500916, -0.20391591, -0.06684849, 0.23080888, -0.33867043, 0.0283998 , 0.05314478], [-0.00075892, 0.15760201, 0.07448781, -0.02245309, -0.27052298, -0.02893338, 0.24336715, -0.29186094, -0.0194285 , -0.04287457], [-0.07933994, 0.1979461 , 0.0075167 , 0.02225138, -0.36191055, -0.01780383, 0.10470055, -0.2530624 , 0.044648 , -0.06972083], [-0.00652711, 0.17066438, -0.03553644, -0.07859261, -0.4780447 , -0.11430528, 0.17088383, -0.35727382, 0.01104823, -0.02751846], [ 0.01646123, 0.02834756, -0.15534458, -0.04221815, -0.17725925, -0.1818274 , 0.0135299 , -0.20522466, -0.03495998, -0.04329636], [-0.17032023, 0.1409552 , -0.14403607, -0.03054993, -0.40012705, -0.09377775, 0.00878634, -0.12934831, 0.03538194, 0.02580342], [-0.1051839 , 0.24074066, -0.08009566, -0.03220345, -0.39412522, -0.09787439, 0.25904325, -0.29801127, 0.029187 , 0.14273585], [-0.02889195, 0.27769655, 0.14630178, 0.00704605, -0.42074952, -0.09446865, 0.22208358, -0.2960379 , -0.10665543, 0.00482226], [ 0.108211 , 0.1608254 , -0.12860669, -0.15433891, -0.16320105, -0.29223463, 0.14248174, -0.1457768 , 0.076525 , -0.10444155], [ 0.05384266, 0.02627722, 0.05595293, -0.05162864, -0.09512395, 0.02082699, 0.14003542, -0.21606655, -0.02665446, -0.15942779], [ 0.08897467, 0.12156868, 0.0072432 , -0.03514582, -0.2993672 , -0.16040912, 0.13929161, -0.25935876, 0.03195405, -0.17140916], [-0.05919135, 0.10460456, -0.0541361 , -0.05133465, -0.29787263, -0.01104981, 0.09286734, -0.18553476, 0.03485874, 0.02347505], [ 0.10552 , 0.13782081, -0.01198097, 0.02319556, -0.23136751, -0.20539609, 0.30999058, -0.33470595, -0.05787981, -0.03471535], [ 0.01394254, 0.04610374, -0.1555309 , 0.02138674, -0.15324359, -0.13208178, 0.13295804, -0.31510523, -0.04771263, -0.12409957], [ 0.08796158, 0.2096323 , 0.03023741, -0.06917568, -0.18738158, -0.04232989, 0.24464765, -0.32475582, -0.0190682 , -0.12991205], [-0.02886308, 0.1945777 , -0.19298053, -0.09160493, -0.32743698, -0.11305106, 0.08422519, -0.17924672, 0.06171962, 0.02897899], [ 0.12523907, 0.19837332, 0.08683064, 0.08463505, -0.28432304, -0.16086362, 0.2658471 , -0.3375882 , -0.00523884, -0.11530196], [ 0.02348459, 0.06094605, -0.20900917, 0.08927577, -0.20602939, -0.09567806, 0.16529328, -0.22091651, 0.0310331 , -0.01366656], [-0.03577258, 0.16763109, -0.00228991, -0.03373875, -0.43727922, -0.23163418, 0.15620582, -0.30316573, 0.02358727, 0.02402775], [-0.01045046, 0.23059595, -0.10650987, 0.04214311, -0.35271657, -0.15367958, 0.18296245, -0.3670715 , 0.00242524, -0.01590508], [ 0.11234806, 0.127929 , -0.1416234 , 0.01856027, -0.1370542 , -0.07529534, 0.15012704, -0.11457616, 0.05310739, -0.23884141], [-0.12933847, 0.18685465, -0.0281913 , 0.03565511, -0.31908727, -0.02797761, 0.01718334, -0.21260841, 0.08538416, 0.10911464], [ 0.02774051, 0.14014125, -0.11116177, 0.06069104, -0.2972231 , -0.08201168, 0.18919642, -0.2732424 , -0.07256175, -0.04372493], [-0.05285444, 0.09731005, -0.16486159, 0.14584875, -0.24468833, -0.11813442, 0.1891388 , -0.27349263, 0.02680733, 0.02208898], [ 0.01703629, 0.14432956, -0.07905609, -0.05359066, -0.2380548 , -0.07374218, 0.18941419, -0.2671988 , -0.04146045, -0.01040378], [-0.01229454, 0.2238013 , -0.07782885, -0.05288902, -0.34084964, -0.07656305, 0.25703445, -0.38321817, 0.02506851, 0.0528542 ], [-0.09558831, 0.05508473, -0.20615673, -0.07739481, -0.35988048, -0.23042265, 0.05733413, -0.20256129, -0.04815403, -0.07377221], [-0.12882753, 0.13142166, -0.13448793, 0.07987095, -0.35866457, -0.03530126, 0.09230228, -0.17570223, 0.09357665, 0.06170238], [-0.01151971, 0.20554355, -0.04922439, -0.08847281, -0.3856749 , -0.0984439 , 0.13491108, -0.3456065 , 0.08735792, -0.03551546], [ 0.14842089, 0.34935832, 0.09578269, 0.01213365, -0.28933874, -0.07850939, 0.19179463, -0.3372814 , 0.12811774, -0.01675376], [-0.18260533, 0.04856708, -0.11119141, -0.07749753, -0.3635567 , -0.23310214, 0.13667285, -0.14879341, -0.14041449, 0.07667582], [-0.17618516, 0.1462372 , -0.18028377, 0.06713124, -0.38335955, -0.03557798, 0.10410185, -0.17592236, -0.01435028, 0.15338154], [ 0.04055981, 0.24520999, -0.17271513, -0.14913839, -0.3042269 , -0.12039592, 0.2689139 , -0.32811943, -0.00502907, 0.00336872], [ 0.03560546, 0.1754866 , -0.11660206, -0.03360491, -0.27283472, -0.18731177, 0.08661997, -0.09050013, 0.04968301, -0.07649884], [-0.02088141, 0.13798103, -0.18891963, 0.00372715, -0.16121043, -0.1467973 , 0.05670452, -0.16789939, 0.09856127, 0.01790601]], dtype=float32)>, 1: <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[-0.09833355, 0.15665436, -0.07166141, 0.01613978, -0.3543378 , -0.08038074, 0.2748848 , -0.28265885, -0.04764143, 0.09244367], [-0.10978055, 0.0580076 , -0.16466689, -0.07054193, -0.37483492, -0.22544345, -0.01834334, -0.23438914, -0.05652276, -0.00428347], [ 0.06102126, 0.23393367, -0.12214032, -0.1585012 , -0.47405833, -0.18904763, 0.14417635, -0.31168947, 0.07823397, -0.17237163], [ 0.13491186, -0.09206396, -0.17387056, -0.02996402, -0.07794836, -0.10825613, 0.09467977, -0.12667504, -0.10682128, -0.2489583 ], [-0.03715318, 0.09756672, -0.13965347, -0.1194995 , -0.2671111 , -0.21588892, 0.20645793, -0.16797486, -0.0715595 , 0.10796719], [-0.02775428, 0.06620544, -0.07389922, -0.08922986, -0.24126638, -0.19932179, 0.09457048, -0.27572516, 0.00439627, -0.13424556], [-0.13065772, 0.04677206, -0.11908375, -0.03384325, -0.29424554, -0.08093098, 0.04305109, -0.12765141, -0.09444566, 0.03544688], [ 0.13585913, 0.06523642, -0.05614541, -0.17160474, -0.16298045, -0.24585429, 0.2457303 , -0.1337391 , -0.03644121, -0.06396304], [-0.00827954, 0.17156884, -0.08429323, -0.0831846 , -0.29328197, -0.13459298, 0.11706164, -0.19470227, 0.06820276, -0.13680328], [-0.08339091, 0.2388103 , -0.11374965, 0.10000415, -0.40983635, -0.10774978, 0.24996822, -0.33086625, -0.05999116, 0.12141761], [ 0.09205646, -0.01435722, -0.17090788, -0.05419794, -0.07356755, -0.03496896, 0.0213257 , -0.11919186, 0.00437786, -0.21572636], [ 0.02209954, 0.11327548, -0.08586521, -0.10829636, -0.1623912 , -0.14285421, 0.08728215, -0.17823447, 0.10733861, -0.12263825], [-0.01373981, 0.22072753, -0.13233714, 0.04230879, -0.38872594, -0.08226185, 0.1575003 , -0.24028088, 0.07385449, -0.03566224], [ 0.03837798, 0.13289425, 0.04731251, -0.02162105, -0.20273682, -0.05832739, 0.18874544, -0.18923599, -0.00774726, -0.07874206], [-0.03777552, -0.00108435, -0.1691374 , 0.07609938, -0.14852983, -0.0603545 , 0.03725195, -0.07498097, 0.03552598, 0.09923848], [-0.00528672, 0.09970027, -0.18529828, 0.06165883, -0.18298753, -0.18982115, 0.13438234, -0.20037794, 0.04573667, -0.12467755], [ 0.02150192, 0.19947992, 0.07855526, -0.06914394, -0.3714067 , -0.11594264, 0.09797562, -0.26935732, -0.02487007, 0.04328841], [-0.02913332, 0.06816912, -0.15824044, -0.09580886, -0.27435338, -0.27287874, 0.03502546, -0.24163257, -0.0320179 , -0.01755684], [-0.09645626, 0.26521719, -0.02974376, -0.11811897, -0.26733863, -0.11098027, 0.17992583, -0.17170358, 0.13684484, 0.11146528], [ 0.02255949, 0.17668228, -0.18393126, 0.03314231, -0.24813469, -0.15481868, 0.09138933, -0.20842537, 0.14024018, -0.09168213], [ 0.03577811, 0.20097674, 0.01054525, 0.0281911 , -0.37472528, -0.01357181, 0.06443338, -0.3450846 , -0.0717351 , -0.05323324], [-0.07593666, 0.01255987, -0.22572437, -0.00785737, -0.2546096 , -0.07944219, 0.1108809 , -0.24050121, 0.04689084, -0.06361024], [-0.06727166, 0.12973833, -0.13008472, 0.05364395, -0.40278557, -0.05354337, 0.15421708, -0.19313724, 0.034807 , 0.0324662 ], [ 0.01260601, 0.17873257, -0.04235746, 0.04865025, -0.3416106 , -0.18992788, 0.20156004, -0.35417786, -0.02896209, -0.06149546], [-0.21761686, 0.04707823, -0.10519136, -0.02330001, -0.44696686, -0.06795169, -0.0549196 , -0.07158573, 0.02182202, 0.01991212], [-0.04528661, 0.11444949, -0.05191335, -0.00863688, -0.3279702 , -0.08010425, 0.02780636, -0.18188381, 0.01212155, 0.02914584], [-0.07794154, 0.18604933, -0.08827695, 0.00928481, -0.43883204, -0.08850133, 0.17446704, -0.2918891 , 0.07034804, 0.02960903], [-0.16039802, 0.06828902, 0.04284256, 0.0398021 , -0.3121784 , -0.09335808, 0.03558116, -0.11442499, -0.02113784, 0.20230052], [ 0.08210264, -0.05593535, -0.23607862, -0.01597648, -0.14844231, -0.1827047 , 0.0761399 , -0.16778287, -0.07931188, -0.19862217], [ 0.12508929, -0.06630521, -0.17870858, -0.00659649, -0.0866005 , -0.1466447 , 0.11984183, -0.11472597, -0.07381313, -0.25458795], [-0.16582993, 0.12019867, -0.12963267, 0.01781088, -0.3867995 , -0.0929767 , 0.05690374, -0.18877956, -0.09080505, 0.12845796], [-0.04166535, 0.11107228, -0.13034783, 0.11846082, -0.2223956 , -0.04571397, 0.19896089, -0.21065699, 0.02911645, 0.0283865 ], [ 0.05492067, 0.14563766, -0.05812249, -0.11080064, -0.15607221, -0.14377254, 0.11191408, -0.03224962, 0.08461984, -0.10316228], [-0.10184978, 0.1774663 , -0.04064796, -0.00778899, -0.24548785, -0.04007315, 0.15378167, -0.16978785, 0.08297548, 0.15654306], [ 0.03246473, -0.00912588, -0.13997413, 0.01901177, -0.15169062, -0.1233996 , 0.0399372 , -0.14586736, -0.04990026, -0.05761716], [-0.04208291, 0.06552696, -0.1491316 , -0.05165012, -0.30772424, -0.14849369, 0.05366952, -0.20017433, -0.00489577, -0.05179733], [ 0.07999136, 0.1268666 , -0.05649832, -0.11447802, -0.093622 , -0.16343006, 0.18737765, -0.15467761, 0.02121065, -0.18551901], [-0.09349701, 0.15481542, 0.02054488, -0.01257675, -0.45885608, -0.07252534, 0.11548893, -0.25998053, -0.15728104, 0.13945304], [-0.12011181, 0.08322111, -0.13839091, 0.01295412, -0.31343067, -0.12020873, 0.10251326, -0.23772089, -0.07965665, 0.12370199], [ 0.04355327, 0.24670643, -0.10936423, -0.11540431, -0.47812563, -0.11112159, 0.171045 , -0.27362657, 0.05597766, -0.08047073], [-0.07981615, 0.1407882 , -0.05179737, 0.11515354, -0.39608997, 0.12418469, 0.12287964, -0.35551202, 0.05395594, 0.106438 ], [ 0.0495141 , 0.16530135, -0.07892615, -0.04727711, -0.17437093, -0.10992847, 0.11209463, -0.17067385, 0.15505184, -0.04079971], [ 0.0399215 , 0.21293016, -0.0784945 , -0.05903067, -0.4351067 , -0.0897992 , 0.19701788, -0.29322943, -0.02085371, -0.08934651], [-0.03191719, 0.09556475, -0.02399754, -0.07302108, -0.31528878, -0.01236389, 0.07309026, -0.25867784, -0.02151406, 0.04411262], [ 0.1144869 , -0.07487808, -0.09897751, -0.03669246, -0.07122576, -0.06312343, 0.0934795 , -0.08661507, -0.06738903, -0.19981778], [ 0.05407554, 0.29886216, -0.04840349, -0.08312691, -0.40975356, -0.30462724, 0.18296465, -0.36522552, -0.01246376, -0.03964518], [ 0.0863948 , 0.27042687, 0.07862431, -0.07175243, -0.23997524, -0.19032082, 0.24878645, -0.2724223 , 0.04755342, -0.08436754], [ 0.01792005, 0.1303249 , 0.00751538, -0.04509098, -0.24403217, -0.10693009, 0.12554878, -0.22423464, -0.01941013, -0.05743247], [ 0.10427958, 0.04477797, -0.21033043, -0.0367108 , -0.23080736, -0.14509909, 0.09560896, -0.15382816, 0.05470648, -0.18330926], [ 0.06824732, -0.00402825, -0.15888713, -0.07463223, -0.16558173, -0.08697934, 0.0755794 , -0.16385713, -0.01368479, -0.2683872 ], [-0.18796708, 0.08231901, -0.10495865, -0.02746309, -0.4590808 , -0.09270471, 0.11427653, -0.18020634, -0.08563931, 0.114185 ], [-0.09791254, 0.03529916, -0.01679311, 0.04053622, -0.28840482, -0.05182896, 0.09699607, -0.0433149 , -0.04200018, -0.01914415], [-0.11433886, 0.13621204, -0.14640698, 0.09180373, -0.340676 , -0.04544587, 0.16125873, -0.22271709, -0.01954299, 0.10093789], [-0.10057668, 0.16380787, -0.11834459, -0.00495677, -0.43168932, -0.06923562, 0.10913345, -0.17641947, -0.02874756, 0.06140285], [-0.06116604, 0.13700888, -0.03211374, -0.01969503, -0.37303066, -0.01124507, 0.18596011, -0.20456824, -0.04335499, 0.06383611], [-0.06899989, 0.13957006, -0.15314595, 0.00410667, -0.41311985, -0.06501595, 0.21060655, -0.31016892, -0.09425378, 0.05115268], [ 0.00838745, 0.1184913 , -0.14133814, 0.11310323, -0.19835803, -0.11010459, 0.20528942, -0.3756318 , 0.00677975, -0.09190771], [-0.01702203, 0.23581095, -0.11833926, -0.05140661, -0.3833909 , -0.2917281 , 0.15348044, -0.2923426 , -0.01539594, -0.07770756], [ 0.00376625, 0.14644173, -0.22676727, -0.00695115, -0.25346804, -0.08446033, 0.15793926, -0.28682148, 0.06036796, -0.05657917], [-0.15886536, 0.22905974, -0.08670876, 0.02516738, -0.44418338, -0.05140979, 0.12132007, -0.24890812, -0.12002567, 0.17170556], [-0.00594457, 0.1702473 , -0.04271276, 0.03786185, -0.33728057, -0.10691148, 0.14885393, -0.2579561 , 0.03387661, -0.12696798], [ 0.12660183, 0.0719447 , -0.0284185 , -0.01238126, -0.1789788 , -0.14557405, 0.11981218, -0.27433586, -0.04192446, -0.15991354], [ 0.01714662, 0.07349592, -0.13597934, 0.07258661, -0.17497018, -0.1185775 , 0.12789062, -0.21281385, 0.0407939 , -0.12549406], [-0.05560762, 0.09808804, -0.03752606, 0.03666614, -0.34828562, -0.08055064, 0.14122654, -0.35993794, -0.03219686, -0.07910113]], dtype=float32)>, 2: <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[ 0.1058839 , 0.15189251, -0.06218533, -0.07259189, -0.21885905, -0.18747613, 0.19386706, -0.20872134, -0.06631209, -0.17281003], [-0.06724259, 0.16423771, -0.09042411, 0.06249507, -0.35925558, -0.08052117, 0.06635582, -0.21773912, -0.02575091, -0.02225357], [-0.12099646, 0.18426058, -0.14509366, 0.03074409, -0.43172374, -0.14523304, 0.11278548, -0.263721 , 0.08033177, 0.11983414], [-0.02189111, 0.19432685, -0.06076544, -0.11341121, -0.31391418, -0.13985458, 0.01988082, -0.27448916, -0.03138182, -0.09098685], [ 0.05217094, 0.10092196, 0.04062135, 0.00210309, -0.2396484 , -0.09674437, 0.11046519, -0.24062449, -0.05751724, -0.1498842 ], [ 0.05167808, 0.17948417, 0.02552646, -0.15288836, -0.2610386 , -0.1887011 , 0.2494685 , -0.25246763, -0.00140664, -0.05116289], [-0.06461833, 0.18794754, 0.01723582, 0.01763999, -0.27383786, 0.01134103, 0.00840673, -0.19577272, 0.04369798, 0.02309319], [ 0.1343988 , 0.02689996, -0.09034813, -0.14479995, -0.11877482, -0.20683356, 0.1408261 , -0.07485476, -0.00536738, -0.10480646], [ 0.14985448, 0.01248868, -0.18233272, -0.04805535, -0.12569216, -0.10754709, 0.10563765, -0.14238729, -0.04125558, -0.30692846], [ 0.0174889 , 0.02731286, -0.09870998, -0.01109812, -0.22202691, -0.15064979, 0.0707394 , -0.16593488, -0.04917879, -0.09892345], [-0.02009851, 0.1603423 , -0.03657382, -0.04635969, -0.376585 , -0.15620278, 0.09772854, -0.331561 , -0.08288915, 0.03283934], [-0.12873377, 0.18706101, -0.08856387, 0.01241396, -0.33643866, -0.04564854, 0.19829091, -0.14762393, 0.05415195, 0.09738437], [ 0.02919928, 0.2304655 , -0.118172 , -0.02557866, -0.20723924, -0.15870155, 0.10659584, -0.29492664, 0.08007928, -0.02028007], [ 0.00755258, 0.13715193, -0.04229313, -0.02631894, -0.3507014 , -0.09775324, 0.08749582, -0.28658673, -0.06305373, -0.06657201], [ 0.01293204, 0.34690908, 0.04428834, 0.06181278, -0.39632797, -0.16290376, 0.3403181 , -0.3575364 , -0.04994334, -0.04745176], [-0.1721677 , 0.15409604, 0.01526155, 0.00422953, -0.40304124, -0.20704678, 0.18548846, -0.1347191 , -0.1029956 , 0.12649345], [ 0.00416036, 0.0911807 , -0.14212427, -0.02980888, -0.41507924, -0.08092883, 0.09573203, -0.3632136 , -0.05752493, -0.10265536], [ 0.08953582, 0.1465146 , -0.0943587 , 0.01540706, -0.1591577 , -0.10180154, 0.26632407, -0.14738992, 0.05417037, -0.19402057], [ 0.09650303, 0.1520698 , -0.12804723, -0.00549375, -0.13437587, -0.10345934, 0.11017533, -0.07745089, 0.07155803, -0.18119216], [-0.11120163, 0.05948977, -0.12460698, 0.07350107, -0.2501851 , -0.10606505, 0.03863706, -0.19143325, 0.05141674, 0.05777346], [-0.06015262, 0.1133145 , 0.02975767, 0.10413057, -0.4479565 , -0.19103777, 0.16624743, -0.23277 , -0.00409902, 0.16262262], [-0.02920761, 0.05220562, -0.01984085, -0.09213866, -0.1952639 , -0.2091724 , 0.09540902, -0.25029248, -0.01369284, -0.07937889], [ 0.08655809, 0.04477531, -0.14414985, -0.13330539, -0.14485738, -0.25247657, 0.161552 , -0.10539306, -0.07818886, -0.22021733], [ 0.05988639, 0.09053953, -0.20782766, 0.0799824 , -0.36067218, -0.11564009, 0.19533472, -0.17254165, -0.03013015, -0.05901307], [ 0.08023848, 0.24398087, -0.03037388, 0.03198615, -0.257546 , -0.0717455 , 0.11799999, -0.2501925 , 0.08931698, 0.01646613], [ 0.01829973, 0.14714168, 0.08777735, -0.04686411, -0.3246078 , -0.11189084, 0.16163842, -0.27373928, -0.05865945, -0.12625104], [ 0.01551664, 0.15430057, -0.06801915, 0.04362161, -0.23502748, -0.10225095, 0.20541084, -0.3872087 , -0.03092569, -0.10081216], [-0.0215794 , 0.2048584 , -0.1525181 , -0.07231435, -0.35670137, -0.20132679, 0.16621551, -0.23583391, 0.03286067, 0.00162398], [-0.16785663, 0.14515236, -0.17006704, -0.07688057, -0.33053195, -0.30842096, 0.08084383, -0.16129194, -0.01640936, 0.00512123], [-0.01012924, 0.18210989, -0.11657718, 0.01039248, -0.3604356 , -0.06368212, 0.17732035, -0.32601178, 0.00564714, 0.079637 ], [-0.02081634, 0.1340083 , -0.04119257, -0.11712583, -0.24917075, -0.14567725, 0.11910035, -0.19569927, -0.0334375 , -0.02956842], [ 0.01040663, -0.03312722, -0.16251259, -0.06878473, -0.29412496, -0.12144031, 0.05180205, -0.18370074, -0.08275174, -0.17228474], [ 0.16995256, 0.15084882, 0.06516901, -0.12424685, -0.28156468, -0.12124009, 0.1919595 , -0.3235635 , -0.16550806, -0.2237745 ], [ 0.14201358, 0.29564446, -0.05666097, -0.11691651, -0.38729382, -0.12142777, 0.18535727, -0.4062221 , -0.02861272, -0.08876766], [-0.13636869, 0.14419225, -0.12272403, -0.12043885, -0.25843346, -0.2624993 , 0.11443155, -0.10935706, 0.07051769, -0.07280853], [ 0.1506514 , 0.231885 , 0.05595114, -0.00167914, -0.20262057, -0.12446204, 0.20270723, -0.24610087, 0.13529469, -0.02456838], [ 0.05419209, 0.14951019, -0.07615753, -0.01649981, -0.21320246, -0.10147403, 0.21325038, -0.3668218 , -0.06149146, -0.04654743], [-0.106455 , 0.21365353, -0.0212086 , 0.00301608, -0.42738837, -0.03399226, 0.21012497, -0.27643734, -0.04460012, 0.09438403], [-0.20142697, 0.2490167 , -0.0610823 , 0.03178958, -0.46005037, -0.11769256, 0.12650084, -0.11382679, 0.05542684, 0.20377554], [ 0.06785811, 0.20105124, -0.06821625, -0.07295625, -0.3634079 , -0.14018053, 0.1070195 , -0.32426855, -0.02956013, -0.08126846], [-0.16207823, 0.10067027, -0.17504756, 0.05559444, -0.41034544, -0.10927348, 0.10540979, -0.1979583 , -0.03317957, 0.13901047], [-0.06511355, 0.12664205, 0.01013242, -0.1038757 , -0.25350553, -0.07925887, -0.01203869, -0.04564004, 0.0292751 , 0.04884695], [ 0.03302544, 0.25231293, 0.02099521, 0.08474679, -0.34921703, -0.09708492, 0.11735824, -0.3118752 , -0.04610436, -0.00284932], [-0.05722202, 0.23294924, 0.06066301, 0.0403684 , -0.30209425, -0.11734634, 0.0649392 , -0.22963187, -0.00225438, 0.09428503], [ 0.03891737, 0.2118478 , -0.02974258, -0.0014833 , -0.41377792, -0.10676192, 0.09491015, -0.30824155, -0.05244756, -0.13116692], [ 0.0387698 , 0.12415141, -0.03787622, -0.01751001, -0.23610467, -0.101076 , 0.10127868, -0.21973278, 0.10053106, -0.11506974], [-0.16079637, 0.12975846, -0.05280735, 0.0171553 , -0.31617314, -0.05437503, 0.11548977, -0.22737938, 0.0040158 , 0.1828981 ], [-0.00830905, 0.06583936, 0.05894554, -0.06423548, -0.25190187, -0.16529903, 0.11135682, -0.20058249, -0.12155165, 0.0516177 ], [ 0.11394778, 0.13319941, -0.14848323, -0.00475912, -0.14136326, -0.08358698, 0.11641169, -0.11600958, 0.04616823, -0.24377736], [-0.04899902, 0.02212415, -0.07927552, 0.03247839, -0.28332067, -0.05081623, 0.10070023, -0.18853599, -0.04962738, -0.07133716], [ 0.06972719, 0.16070195, -0.06629859, 0.02116043, -0.2982213 , -0.06711708, 0.15437523, -0.31704813, 0.09814174, 0.02899931], [ 0.10892028, 0.06948914, 0.00259475, -0.14409684, -0.05081703, -0.10783089, 0.11664214, -0.1612539 , 0.03825954, -0.0974022 ], [ 0.11645016, 0.09103441, -0.08829162, -0.03281923, -0.11439721, -0.04074144, 0.13266224, -0.16234764, 0.0108225 , -0.1980096 ], [ 0.04067405, 0.28472447, 0.06697901, 0.00244627, -0.44801572, -0.17432919, 0.14524695, -0.33887076, -0.04374426, -0.01839055], [-0.07228023, 0.0094237 , -0.0114438 , 0.05621499, -0.1906197 , -0.10239108, 0.1188737 , -0.06362449, -0.04171722, -0.04671538], [ 0.11314879, 0.14128837, -0.10640095, -0.00203198, -0.1638219 , -0.14681938, 0.15794587, -0.13135739, -0.00294857, -0.2193684 ], [-0.189158 , 0.14615512, -0.1951129 , 0.01067956, -0.46317485, -0.12231541, 0.17576897, -0.13713211, 0.05281383, 0.05961443], [ 0.13483778, 0.15122853, -0.06812572, 0.02223442, -0.31159732, -0.14329562, 0.13942383, -0.3621065 , 0.0194601 , -0.17312187], [-0.07476285, 0.056327 , -0.1509085 , 0.12273326, -0.25021672, -0.09043095, 0.14668559, -0.15018652, -0.03156234, 0.02678441], [-0.00460077, 0.1331944 , -0.14329152, -0.05613213, -0.3445416 , -0.16092777, 0.1215454 , -0.25038692, 0.04866181, -0.1493005 ], [ 0.01425789, 0.12178193, 0.04116632, -0.08268491, -0.24872722, -0.13918278, 0.05446562, -0.19651294, 0.02265346, -0.06737643], [-0.01926323, 0.1012551 , -0.05458277, 0.046477 , -0.2919128 , -0.13312258, 0.08117634, -0.38422066, -0.01930849, 0.00814056], [ 0.11184652, 0.30679387, 0.09332582, -0.01897855, -0.2433966 , -0.15644044, 0.19370434, -0.18590672, 0.05792751, -0.14004235], [-0.01329161, 0.13452995, -0.22033097, 0.07828833, -0.29261756, -0.07592015, 0.13703327, -0.16651678, 0.01105583, -0.0150648 ]], dtype=float32)>, 3: <tf.Tensor: shape=(64, 10), dtype=float32, numpy= array([[ 2.51155347e-03, 1.45447791e-01, 6.41264021e-04, 1.91463828e-02, -1.85839653e-01, -1.03563413e-01, 6.61028773e-02, -1.90979049e-01, 6.06872290e-02, -1.41341135e-01], [-4.54837829e-02, 1.08597815e-01, -1.17722288e-01, 1.42128780e-01, -2.34650105e-01, -1.13956973e-01, 1.98587418e-01, -2.94802964e-01, -1.37891099e-02, 4.00971621e-04], [-1.90911219e-02, 1.18707336e-01, -3.10591292e-02, 4.01906706e-02, -3.41950208e-01, -2.94433124e-02, 1.32617578e-01, -2.99225211e-01, -3.27008069e-02, -1.80488713e-02], [-4.05668318e-02, 1.91518039e-01, -2.78929994e-02, 6.00643232e-02, -1.37896061e-01, -9.06022638e-02, 1.69026867e-01, -3.06107998e-01, 7.56823719e-02, 6.13778904e-02], [-5.93572706e-02, 1.76578969e-01, -1.66006207e-01, 5.42817116e-02, -3.62876147e-01, -1.09547049e-01, 1.69198722e-01, -2.94477433e-01, 1.92826986e-02, 1.70004927e-03], [-1.40251517e-02, 1.31867789e-02, -3.96135375e-02, -9.33120027e-02, -2.40361691e-01, -1.00401580e-01, 1.22912303e-02, -1.80318967e-01, -8.82146955e-02, -1.93105519e-01], [ 1.16194248e-01, 1.79252625e-02, -1.44289330e-01, -4.96181250e-02, -1.14380077e-01, -9.49703977e-02, 9.81448591e-02, -1.50415093e-01, -3.23846750e-02, -2.56290972e-01], [-1.55512661e-01, 1.78086460e-01, -3.19233388e-02, 2.66162679e-02, -4.40953791e-01, 8.28160048e-02, 9.33465362e-02, -2.50108629e-01, 3.17218527e-02, 1.32439479e-01], [-6.79045916e-02, 1.01774186e-01, -1.06924191e-01, -3.54708657e-02, -2.07453206e-01, -3.10046911e-01, 1.38181359e-01, -2.24454358e-01, -5.50230630e-02, -5.98306395e-02], [-7.62961432e-02, 1.54165313e-01, 2.80987248e-02, -1.35838151e-01, -3.44427884e-01, -1.93413466e-01, 7.58136064e-02, -1.11874819e-01, 1.76186338e-02, 1.20480917e-02], [ 1.22831598e-01, 1.30915672e-01, 7.43995085e-02, -7.53342807e-02, -2.06910223e-01, -1.84539944e-01, 2.61359870e-01, -2.82187372e-01, -6.73582032e-03, 1.24582648e-03], [ 2.05421969e-02, 2.17275023e-01, -1.98052749e-01, -7.11035356e-03, -3.69731605e-01, -1.85717657e-01, 4.95720804e-02, -1.15070224e-01, 9.22817066e-02, -8.54254588e-02], [ 6.88252151e-02, 1.44817978e-01, -1.33312225e-01, -5.94893545e-02, -2.51493931e-01, -1.20750599e-01, 1.42469034e-01, -3.34877908e-01, -1.80790052e-02, -2.16673821e-01], [-1.82654113e-01, 1.76363677e-01, -1.04193613e-01, -1.50571056e-02, -5.33530712e-01, -6.55154735e-02, 2.52551317e-01, -2.92810172e-01, 2.38716491e-02, 1.04600534e-01], [ 2.09688932e-01, 6.80498332e-02, 4.03034799e-02, -6.70600384e-02, -9.22301486e-02, -3.91195267e-02, 2.45474190e-01, -1.92359731e-01, 1.75092369e-03, -1.95264012e-01], [-2.23329328e-02, 1.45669162e-01, -1.52139455e-01, 5.34053110e-02, -2.56344557e-01, -9.23609883e-02, 1.19958594e-01, -2.34045699e-01, 3.79388183e-02, -7.73888920e-03], [-1.97878480e-02, 1.73301339e-01, -9.23557431e-02, 8.06293264e-02, -3.48612636e-01, -4.22735885e-03, 2.06312776e-01, -3.54612112e-01, 4.30748984e-02, 2.62675956e-02], [ 2.90626474e-03, 1.41105443e-01, -9.56210792e-02, 2.42822953e-02, -1.43771321e-01, -5.69432080e-02, 1.31530389e-01, -1.60310119e-01, 8.27911496e-02, -3.81680019e-02], [ 9.71606374e-03, 9.51089263e-02, -1.06061004e-01, -3.78770605e-02, -4.01138932e-01, -1.43262267e-01, 1.28680661e-01, -2.61636645e-01, -1.15659922e-01, -1.11245878e-01], [ 5.21949939e-02, 1.50751829e-01, 9.35174376e-02, 3.17564085e-02, -3.50595593e-01, -1.38468161e-01, 1.85263827e-01, -3.03996593e-01, -5.42629510e-03, 7.91802928e-02], [ 1.37041256e-01, 1.69497833e-01, -8.63870382e-02, 1.57636702e-02, -1.79991171e-01, -9.08993408e-02, 1.56794131e-01, -1.47339672e-01, 2.86077783e-02, -2.16370448e-01], [-2.69663520e-03, 2.21826449e-01, 5.43841496e-02, -1.17137842e-01, -2.97215730e-01, -8.82113725e-02, 1.05724975e-01, -2.20105618e-01, 4.52305302e-02, 3.74987908e-03], [-9.54074338e-02, 9.03669074e-02, -1.19427562e-01, -3.28763574e-03, -2.99989522e-01, -4.01741937e-02, 1.19693980e-01, -3.00709546e-01, -1.82122067e-02, 6.16809502e-02], [-9.99720395e-02, 1.39634579e-01, -2.00303137e-01, 6.65970296e-02, -3.51577580e-01, -9.84932706e-02, 9.02529806e-02, -2.44839266e-01, 1.13900602e-02, 6.42385334e-02], [ 1.44478194e-02, 1.83509141e-01, -6.67656586e-02, -4.61032502e-02, -4.40125525e-01, -7.73666874e-02, 3.11786264e-01, -4.44831997e-01, -2.93702632e-02, 4.22365218e-03], [ 1.32541731e-01, 1.82956174e-01, -1.04927093e-01, 1.22469440e-02, -1.27488539e-01, -6.61835223e-02, 1.12031206e-01, -9.84050632e-02, 9.86539871e-02, -2.05507338e-01], [-6.50153607e-02, 1.74549997e-01, -1.71862602e-01, 6.14198335e-02, -4.02740180e-01, -4.20516059e-02, 2.40440428e-01, -2.76678443e-01, -5.84294945e-02, 1.08706802e-01], [ 7.48759061e-02, 2.10499942e-01, 7.57246763e-02, -3.62860188e-02, -3.67509723e-01, -7.15096444e-02, 8.84700716e-02, -3.12674284e-01, -9.51475725e-02, -5.33458181e-02], [-7.26057142e-02, 9.15311202e-02, -4.78885621e-02, -1.29553005e-02, -2.53760219e-01, -1.62764773e-01, 7.15883225e-02, -1.82869673e-01, -7.55081400e-02, -3.04650366e-02], [-1.08026592e-02, 1.12620339e-01, 1.36063978e-01, 7.70925432e-02, -1.93178192e-01, 1.07710674e-01, 1.66693091e-01, -2.04272971e-01, -4.38663363e-03, 6.33232668e-02], [ 1.56896949e-01, 8.96635503e-02, 4.47561368e-02, -7.04529062e-02, -6.03420623e-02, -5.07256314e-02, 1.84834778e-01, -2.29554027e-01, 2.17712931e-02, -1.52052462e-01], [-1.15724243e-01, 8.04387107e-02, -5.88226505e-02, 8.92069191e-04, -2.94122517e-01, -1.87400073e-01, 2.21029714e-01, -1.70096382e-01, 2.51032859e-02, 1.31270349e-01], [ 1.54940831e-02, 1.65156215e-01, 4.22797874e-02, -3.61752585e-02, -2.75962710e-01, -1.24710202e-01, 8.39878917e-02, -2.54829884e-01, 1.27165876e-02, 3.31448913e-02], [ 2.08573253e-03, 1.30548522e-01, 1.03515625e-01, -5.84256947e-02, -2.24049121e-01, -8.96962434e-02, 2.11826205e-01, -2.76590139e-01, -5.25184534e-02, 5.65478951e-02], [ 1.21045813e-01, 3.13128859e-01, 9.14186686e-02, -5.13909534e-02, -2.90747494e-01, -1.41330570e-01, 1.61255151e-01, -4.09655958e-01, 3.58102247e-02, 2.64274161e-02], [ 1.41505450e-02, 8.61316472e-02, -1.45660698e-01, 5.86940274e-02, -3.43586266e-01, -8.44506472e-02, 1.70025736e-01, -2.38322318e-01, -4.34658676e-03, -2.87827849e-03], [ 7.29348212e-02, 1.93091661e-01, -1.56394571e-01, 4.55366895e-02, -3.35400045e-01, -1.32662982e-01, 7.06198215e-02, -1.37381166e-01, 8.84873420e-02, -1.14571638e-01], [-3.92792933e-02, 1.48667991e-01, -9.71383378e-02, -6.74230531e-02, -3.02699536e-01, -2.55341195e-02, 1.27064362e-01, -2.09095284e-01, 1.01266101e-01, 1.40049338e-01], [ 1.23310566e-01, 7.76651576e-02, -5.40870428e-03, 4.73390855e-02, -2.56887197e-01, 1.34499185e-02, 1.85226709e-01, -2.55586505e-01, 4.35908213e-02, -3.50753143e-02], [-8.64604563e-02, 2.31015921e-01, -5.34078926e-02, -5.69486693e-02, -4.90515888e-01, -1.08600266e-01, 1.19341165e-01, -3.48818719e-01, -1.02463067e-02, 5.82626238e-02], [ 2.33721081e-02, 8.51243734e-04, -1.55263841e-01, -4.45500016e-02, -2.12267876e-01, -1.53808981e-01, 6.14735037e-02, -1.64876774e-01, -1.04674906e-01, -1.51574746e-01], [-1.39495209e-01, 1.52735859e-01, -1.02774106e-01, 3.62434834e-02, -3.99802774e-01, -4.73229587e-02, 2.26968199e-01, -2.60911047e-01, -4.54405621e-02, 1.17595874e-01], [ 1.58262476e-02, 8.54684860e-02, -1.06508911e-01, 1.02037542e-01, -3.03655088e-01, -7.41694570e-02, 2.05974013e-01, -2.92493582e-01, -1.05211154e-01, 5.25907800e-03], [-8.26321095e-02, 1.80413112e-01, -1.01069041e-01, -3.26163694e-02, -4.50304151e-01, -1.06844909e-01, 5.95432222e-02, -2.05993459e-01, 8.68424997e-02, 9.42370202e-03], [ 4.29255553e-02, 1.63784251e-01, -2.18889564e-02, 6.25932515e-02, -1.61837101e-01, -2.08140716e-01, 2.36108944e-01, -2.10580200e-01, -8.19268376e-02, -8.67268592e-02], [ 1.06723763e-01, 1.73626572e-01, -7.36150444e-02, -3.47835794e-02, -2.27102175e-01, -2.34541237e-01, 1.46386385e-01, -2.80957967e-01, -4.40967083e-03, -5.37140332e-02], [ 2.75464915e-02, 1.41459256e-01, -3.87002230e-02, 2.73079798e-02, -3.48300725e-01, -1.21690288e-01, 2.27318466e-01, -3.45179290e-01, 2.37728134e-02, -4.03352082e-02], [-1.76691502e-01, 1.37643158e-01, -5.67638800e-02, -5.56118786e-04, -3.88640523e-01, -1.25400782e-01, 2.01528460e-01, -2.48365819e-01, -1.57399729e-01, 1.01713941e-01], [-1.59693748e-01, 2.09404677e-01, -1.04986534e-01, -1.48359641e-01, -2.98504502e-01, -2.92371094e-01, 5.03033698e-02, -1.31363750e-01, 1.03016965e-01, 2.24144422e-02], [ 5.60386404e-02, 1.35745630e-01, 3.90878096e-02, 7.47692585e-02, -2.68035203e-01, -7.92922825e-02, 2.57956684e-01, -3.78247499e-01, -6.12850487e-03, -1.08967513e-01], [-3.56525183e-03, 1.20683648e-01, 7.16836527e-02, -4.75444421e-02, -3.20046216e-01, -1.89872205e-01, 1.36546373e-01, -2.40315855e-01, -8.23161006e-02, 9.24457610e-02], [ 8.29515085e-02, -1.42475218e-03, -1.37439638e-01, -7.92527385e-03, -9.10671651e-02, -1.04378976e-01, 9.12457854e-02, -1.41347721e-01, -5.19168749e-02, -2.03986824e-01], [ 2.84444056e-02, 1.29459754e-01, -1.53349116e-02, -8.56753960e-02, -3.15394014e-01, -1.11945629e-01, 6.06864467e-02, -3.40235710e-01, -3.83457541e-02, -9.85145792e-02], [ 1.28373414e-01, -4.11517769e-02, -1.60181478e-01, -7.12700635e-02, -1.03279904e-01, -7.15017915e-02, 7.57745504e-02, -1.26075238e-01, -4.52508181e-02, -2.89755970e-01], [-6.51197582e-02, 2.04959273e-01, -1.04472905e-01, 3.50410938e-02, -2.49693528e-01, -1.15549982e-01, 3.61202434e-02, -1.73137397e-01, -1.81247666e-02, -3.57590988e-03], [-5.83130457e-02, 1.40513688e-01, -1.33702904e-01, 2.82108374e-02, -4.03776646e-01, -9.68370736e-02, 1.41329780e-01, -1.68025374e-01, 3.46826874e-02, 2.28778720e-02], [-3.75167355e-02, 8.63219798e-02, -1.85671210e-01, 2.99670994e-02, -2.33271241e-01, -6.28029704e-02, 1.62515551e-01, -1.93014055e-01, 6.13552369e-02, -1.79618783e-02], [ 3.00599448e-02, 1.02386631e-01, 8.53579044e-02, 1.64787471e-03, -1.77172825e-01, -1.43747658e-01, 5.66788465e-02, -1.46428332e-01, -1.47961304e-02, -4.46532555e-02], [-1.10645309e-01, 1.65324107e-01, 4.55538183e-03, 2.99948044e-02, -3.46696526e-01, 1.56221874e-02, 3.43421698e-02, -1.53912872e-01, -1.63781494e-02, 5.61783984e-02], [-5.74554652e-02, 1.73873678e-01, 2.39674598e-02, 5.73134795e-02, -3.39531004e-01, -2.35163961e-02, 8.80995244e-02, -1.96591303e-01, 4.05878499e-02, 3.93691286e-02], [-6.33653328e-02, 1.92787185e-01, -1.58524215e-01, -2.15648673e-03, -3.60980183e-01, -1.75777331e-01, 1.52091727e-01, -2.15706527e-01, 8.36684853e-02, 3.28490287e-02], [-1.28059506e-01, 1.10793650e-01, -1.03594311e-01, 1.62330829e-02, -3.45485270e-01, -1.49970520e-02, 1.38308808e-01, -2.61050552e-01, 3.11166197e-02, 2.95418501e-02], [-1.12311147e-01, 2.21623197e-01, 4.02755402e-02, 1.26273036e-01, -3.69358778e-01, -3.77379954e-02, 1.73841923e-01, -3.46165776e-01, 2.09579170e-02, 1.62290595e-02], [-7.96634629e-02, 8.21764618e-02, -1.46908149e-01, -1.15463018e-01, -2.95011848e-01, -2.38702789e-01, 4.32551615e-02, -2.11507469e-01, 3.09027694e-02, -1.27929717e-01]], dtype=float32)> } } 2024-01-11 18:14:53.554141: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
復元された関数の呼び出しは、保存されたモデル(tf.keras.Model.predict
)に対するフォワードパスです。読み込まれた関数をトレーニングし続ける場合はどうでしょうか。または読み込まれた関数をより大きなモデルに埋め込むには?一般的には、この読み込まれたオブジェクトを Keras レイヤーにラップして行うことができます。幸いにも、TF Hub には、以下に示すとおり、この目的に使用できる hub.KerasLayer
が用意されています。
import tensorflow_hub as hub
def build_model(loaded):
x = tf.keras.layers.Input(shape=(28, 28, 1), name='input_x')
# Wrap what's loaded to a KerasLayer
keras_layer = hub.KerasLayer(loaded, trainable=True)(x)
model = tf.keras.Model(x, keras_layer)
return model
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
loaded = tf.saved_model.load(saved_model_path)
model = build_model(loaded)
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
optimizer=tf.keras.optimizers.Adam(),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
model.fit(train_dataset, epochs=2)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') 2024-01-11 18:14:54.395157: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:553] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed. Epoch 1/2 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 INFO:tensorflow:Collective all_reduce tensors: 6 all_reduces, num_devices = 4, group_size = 4, implementation = CommunicationImplementation.NCCL, num_packs = 1 235/235 [==============================] - 4s 7ms/step - loss: 0.3436 - sparse_categorical_accuracy: 0.9023 Epoch 2/2 235/235 [==============================] - 2s 7ms/step - loss: 0.1126 - sparse_categorical_accuracy: 0.9682
上記の例では、hub.KerasLayer
は tf.saved_model.load()
から読み込まれた結果を、別のモデルの構築に使用できる Keras レイヤーにラップしています。転移学習を行う際に非常に便利な手法です。
どの API を使用すべきですか?
保存においては、Keras モデルを使用している場合は、低レベル API が実現できる追加の制御が必要でない限り、Keras の Model.save
API を使用します。保存しているものが Keras モデルでない場合は、低レベル API の tf.saved_model.save
しか使用できません。
読み込みにおいては、使用する API はモデルの読み込みから得ようとしているものによって異なります。Keras モデルを使用できない場合(または使用したくない場合)は、tf.saved_model.load
を使用し、使用できる場合は tf.keras.models.load_model
を使用します。Keras モデルを保存した場合にのみ、Keras モデルを読み込めることに注意してください。
API を混在させることも可能です。model.save
で Keras モデルを保存し、低レベルの tf.saved_model.load
API を使用して、非 Keras モデルを読み込むことができます。
model = get_model()
# Saving the model using Keras `Model.save`
model.save(saved_model_path)
another_strategy = tf.distribute.MirroredStrategy()
# Loading the model using the lower-level API
with another_strategy.scope():
loaded = tf.saved_model.load(saved_model_path)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
ローカルデバイスからの読み込みまたは保存
ローカル I/O デバイスから読み込みと保存を行い、リモートデバイスでトレーニングする場合(Cloud TPU を使用する場合など)、tf.saved_model.SaveOptions
と tf.saved_model.LoadOptions
に experimental_io_device
を使用して、I/O デバイスを localhost
に設定する必要があります。以下に例を示します。
model = get_model()
# Saving the model to a path on localhost.
saved_model_path = '/tmp/tf_save'
save_options = tf.saved_model.SaveOptions(experimental_io_device='/job:localhost')
model.save(saved_model_path, options=save_options)
# Loading the model from a path on localhost.
another_strategy = tf.distribute.MirroredStrategy()
with another_strategy.scope():
load_options = tf.saved_model.LoadOptions(experimental_io_device='/job:localhost')
loaded = tf.keras.models.load_model(saved_model_path, options=load_options)
INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3') INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')
警告
Keras モデルを特定の方法で作成してから、トレーニングする前に保存するという、以下のような特別なケースがあります。
class SubclassedModel(tf.keras.Model):
"""Example model defined by subclassing `tf.keras.Model`."""
output_name = 'output_layer'
def __init__(self):
super(SubclassedModel, self).__init__()
self._dense_layer = tf.keras.layers.Dense(
5, dtype=tf.dtypes.float32, name=self.output_name)
def call(self, inputs):
return self._dense_layer(inputs)
my_model = SubclassedModel()
try:
my_model.save(saved_model_path)
except ValueError as e:
print(f'{type(e).__name__}: ', *e.args)
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f7f09dc29a0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f7f09dc29a0>, because it is not built. ValueError: Model <__main__.SubclassedModel object at 0x7f7f09dc29a0> cannot be saved either because the input shape is not available or because the forward pass of the model is not defined.To define a forward pass, please override `Model.call()`. To specify an input shape, either call `build(input_shape)` directly, or call the model on actual data using `Model()`, `Model.fit()`, or `Model.predict()`. If you have a custom training step, please make sure to invoke the forward pass in train step through `Model.__call__`, i.e. `model(inputs)`, as opposed to `model.call()`.
SavedModel は tf.function
をトレースする際に生成される tf.types.experimental.ConcreteFunction
オブジェクトを保存します(詳細は、グラフと tf.function の基本ガイドの関数はいつトレースしますか? をご覧ください)。このような ValueError
が発生した場合、Model.save
がトレースされた ConcreteFunction
を見つけられなかったか作成できなかったことが原因です。
注意: 少なくとも 1 つの ConcreteFunction
がない場合にモデルを保存しないことをお勧めします。そうでない場合、低レベル API は、ConcreteFunction
シグネチャのない状態で SavedModel を生成してしまうためです(SavedModel 形式については、こちらをご覧ください)。以下に例を示します。
tf.saved_model.save(my_model, saved_model_path)
x = tf.saved_model.load(saved_model_path)
x.signatures
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f7f09dc29a0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.SubclassedModel object at 0x7f7f09dc29a0>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.src.layers.core.dense.Dense object at 0x7f7f09c0f250>, because it is not built. WARNING:tensorflow:Skipping full serialization of Keras layer <keras.src.layers.core.dense.Dense object at 0x7f7f09c0f250>, because it is not built. INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets _SignatureMap({})
通常、モデルのフォワードパス(call
メソッド)は、モデルが Keras の Model.fit
メソッドを通じて初めて呼び出されたときに、自動的にトレースされます。また、最初のレイヤーを tf.keras.layers.InputLayer
などにして、input_shape
キーワード引数に渡すことで入力形状を設定している場合、Keras の Sequential API と Functional API によって ConcreteFunction
が生成されることもあります。
モデルにトレース済みの ConcreteFunction
が存在するかを確認するには、Model.save_spec
が None
になっていることを確認します。
print(my_model.save_spec() is None)
True
tf.keras.Model.fit
を使ってモデルをトレーニングし、save_spec
が定義され、モデルの保存が機能するかを確認しましょう。
BATCH_SIZE_PER_REPLICA = 4
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * mirrored_strategy.num_replicas_in_sync
dataset_size = 100
dataset = tf.data.Dataset.from_tensors(
(tf.range(5, dtype=tf.float32), tf.range(5, dtype=tf.float32))
).repeat(dataset_size).batch(BATCH_SIZE)
my_model.compile(optimizer='adam', loss='mean_squared_error')
my_model.fit(dataset, epochs=2)
print(my_model.save_spec() is None)
my_model.save(saved_model_path)
Epoch 1/2 7/7 [==============================] - 1s 2ms/step - loss: 4.7761 Epoch 2/2 7/7 [==============================] - 0s 2ms/step - loss: 4.4889 False INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}). INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}). INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}). INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}). INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}). INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5, 5), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09cb3220>, 140183601399632), {}). INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}). INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(5,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f7f09a6c790>, 140183601400208), {}). INFO:tensorflow:Assets written to: /tmp/tf_save/assets INFO:tensorflow:Assets written to: /tmp/tf_save/assets