View on TensorFlow.org | Run in Google Colab | View source on GitHub | Download notebook |
TensorFlowとXLAライブラリをインポートします。XLAには、一部または全てのモデルを XLA でコンパイルする実験的なAPIである xla.compile()
が含まれています。
import tensorflow as tf
from tensorflow.contrib.compiler import xla
必要ないくつかの定数を定義し、 MNISTのデータセットを用意します。
# それぞれの入力イメージの大きさは、 28 x 28ピクセル
IMAGE_SIZE = 28 * 28
# 個別の数字のラベル [0..9] の個数
NUM_CLASSES = 10
# それぞれのトレーニングバッチ(ステップ)での標本数
TRAIN_BATCH_SIZE = 100
# トレーニングステップを実行する回数
TRAIN_STEPS = 1000
# MNISTデータセットをロードする。
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()
test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)
iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)
images, labels = iterator.get_next()
images = tf.reshape(images, [-1, IMAGE_SIZE])
images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step WARNING:tensorflow:From <ipython-input-4-4b4b30f2fbb2>:5: DatasetV1.output_types (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.data.get_output_types(dataset)`. WARNING:tensorflow:From <ipython-input-4-4b4b30f2fbb2>:5: DatasetV1.output_shapes (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.data.get_output_shapes(dataset)`.
モデルを構築する関数の定義
以下のコードブロックは、順伝搬と逆伝搬の両方を行う、1つのdenseレイヤーを持つ簡単なモデルを構築する関数を含みます。
コードが呼ばれたとき、2つの値を返します。 y
は、それぞれのターゲットのクラスの予測確率を表す tf.Tensor
です。 train_step
は global_step
の値を増加し、変数の更新を行う tf.Operation
です。
def build_mnist_model(x, y_):
y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
return y, train_step
XLA の有効化
XLA を有効化するには build_mnist_model
関数を xla.compile
に渡します。以下のコードブロックは、モデルを xla.compile()
関数でラップします。これにより、提供された入力を持つターゲット関数をXLAで実行できます。
[y] = xla.compile(build_mnist_model, inputs=[images, labels])
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version. Instructions for updating: Call initializer instance with the dtype argument instead of passing it to the constructor WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/losses/losses_impl.py:121: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version. Instructions for updating: Use tf.where in 2.0, which has the same broadcast rule as np.where
グラフをコンパイルするとき、XLAはターゲット関数によって構築されたグラフの全てのノードを、いくつかのXLAのオペレータで置き換えます。
xla.compileは、生成されたXLAのオペレータから独立して実行できる tf.Operation
を返しません
代わりに、ターゲット関数から返された tf.Operation
ノードは、返された全ての tf.Tensor
の値との制御依存関係として追加されます。これにより、 返されたテンソルが評価されるときに、 tf.Operation
ノードの実行をトリガします。
擬似コードによるxla.compileの実装は、以下のようになります:
# TensorFlowに、XLAが扱いやすい方法でコードを実行するよう依頼する
y, train_step = build_mnist_model(images, labels)
with tf.control_dependencies([train_step]):
y = tf.identity(y)
# TensorFlowに、XLAが扱いやすい方法でコードの実行を停止するよう依頼する
xla.compile()は常に tf.Tensor
のリスト(1要素しか無かったとしても)を返します。
もしあなたが構築したグラフを今表示したら、通常のTensorFlowのグラフとそれほど変わらないことがわかり、前に述べたXLAのオペレータを見つけることができないでしょう。これは、あなたが sess.run()
でグラフを実行しようとしても、実際のコンパイルは後ほど発生するからです。後ほど、TensorFlowは実際にXLAオペレータを生成する一連のグラフ書き換えパスをトリガーします。これは、すべての入力がそろったときに、計算をコンパイルして実行します。
モデルの学習とテスト
# セッションを作成しすべての変数を初期化。
# xla.compile()は、Keras model.fit() APIやTF eager modeとはまだ動作しません。
sess = tf.Session()
sess.run(tf.global_variables_initializer())
以下のコードブロックはモデルを学習します。 y
の評価は、制御依存関係がある train_step
をトリガします。これは、モデル変数を更新します。
# 学習用データセットを与える
sess.run(iterator.make_initializer(train_ds))
# TRAIN_STEPS ステップだけ実行する
for i in range(TRAIN_STEPS):
sess.run(y)
print("Model trained for %s steps." % TRAIN_STEPS)
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:348: Iterator.output_types (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.data.get_output_types(iterator)`. WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:349: Iterator.output_shapes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.data.get_output_shapes(iterator)`. WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/iterator_ops.py:351: Iterator.output_classes (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.compat.v1.data.get_output_classes(iterator)`. Model trained for 1000 steps.
# 学習済みモデルをテストする
# テスト用データセットを与える
sess.run(iterator.make_initializer(test_ds))
# 精度を計算する
correct_prediction = tf.equal(tf.argmax(y, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % sess.run(accuracy))
Prediction accuracy after training: 0.91
# セッションを片付ける
sess.close()