XLA: コンパイラを機械学習用に最適化する

XLA(Accelerated Linear Algebra)は、線形代数のためのドメイン固有のコンパイラで、ソースコードを変更せずに TensorFlow モデルを高速化することができます。

XLA を使用すると、速度とメモリ使用量が改善します。たとえば BERT の場合、8 個の Volta V100 GPU(XLA を使用)を使った MLPerf の提出物で、パフォーマンスが約 7 倍、バッチサイズが約 5 倍改善されることが確認されています。

概要

TensorFlow プログラムを実行すると、すべてのオペレーションが TensorFlow エグゼキュータによって個別に実行されます。TensorFlow の各オペレーションは、エグゼキュータによってプリコンパイル済み GPU カーネル実装にディスパッチされます。

XLA では、モデルの実行を代替するモードが用意されています。代替モードでは、TensorFlow グラフが、所定のモデル用に生成された一連のコンピューティング カーネルにコンパイルされます。こうしたカーネルはモデルに固有であるため、モデル固有の情報を利用して最適化できます。TensorFlow で単純な演算を行う場合の XLA による最適化の例を次に示します。

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

XLA なしで実行する場合、乗算、加算、削減用として 3 つのカーネルがグラフで起動されます。XLA を使用すると、グラフが最適化され、1 回のカーネル起動で演算が行われます。これは、加算、乗算、削減を単一の GPU カーネルに「融合」することで行われます。さらに、この融合された演算では、y*zx+y*z で生成された中間値がメモリに書き出されません。代わりに、これらの中間演算の結果を GPU レジスタにすべて保持しながら、ユーザーに直接「ストリーミング」します。融合は、XLA の唯一かつ重要な最適化手法です。 ハードウェア アクセラレータにおいては一般にメモリ帯域幅のリソース上の制約が大きく、このようにメモリ操作を省くことが、パフォーマンスを改善するうえで有効です。

TensorFlow モデルに対して XLA を有効にする

tf.function(jit_compile=True) を使用した明示的なコンパイル

明示的コンパイル API を使用すると、コンパイルする関数を細かく制御できます。たとえば、次の TensorFlow 関数(MNIST トレーニングを実行する)は XLA でコンパイルされます。

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

jit_compile API には「コンパイル必須」セマンティクスがあります。これにより、関数全体が XLA でコンパイルされるか、errors.InvalidArgumentError 例外がスローされます。XLA では現在、次元を推定できない関数はコンパイルできません。つまり、計算全体を実行しないと推定できないテンソルの次元がある場合は、コンパイルできないということです。たとえば、次の関数はコンパイルできません。

@tf.function
def not_compilable(x):
  return tf.unique(x)

ただし、シェイプは実行ごとに変えることができます。

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

使用例について詳しくは、チュートリアルの Colabjit_compile=True の使用に関する チュートリアル動画 をご覧ください。

自動クラスタリング

TensorFlow モデルで何も変更を加えずに XLA を使用する簡単な方法は、自動クラスタリングを有効にすることです。自動クラスタリングにより、XLA を使用してコンパイルと実行が行える TensorFlow 関数内のクラスタ(連結サブグラフ)が自動的に検索されます。GPU での自動クラスタリングは、TF_XLA_FLAGS 環境変数を設定することで有効にできます。

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

現在、自動クラスタリングは GPU ワークロード用に最適化されていますが、次のように --tf_xla_cpu_global_jit フラグを追加することで、CPU に対して有効にすることもできます。

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

詳細な使用例については、自動クラスタリングに関するチュートリアルの Colab をご覧ください。

tfcompile による CPU の AOT(事前)コンパイル

スタンドアロンの tfcompile ツールを使用して、TensorFlow グラフを実行可能コード(x86-64 CPU のみ)に変換することもできます。

コンパイルされたプログラムの検査

XLA には、生成されたプログラムを検査できるイントロスペクション機能が用意されています。生成されたプログラムをダンプするには、環境変数 XLA_FLAGS を次のように使用します。

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

ダンプが行われると、/tmp/generated に次のファイルが生成されます。

  • module_XXXX.*_optimizations.txt: 生成された XLA プログラム(コンパイルされたクラスタごとに 1 つ)。XLA バグレポートの送信時に添付していただけるととても役立ちます。

  • module_XXXX.ir-*.ll: NVPTX を組み込むことで、LLVM による中間表現として生成されたファイル。

  • module_XXXX.ptx: 生成された PTX ファイル。

次のようにして、TensorFlow グラフ内に埋め込まれた XLA クラスタを可視化するグラフをダンプすることもできます。

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

再現可能なバグレポート

バグレポートに、生成された XLA プログラムと使用された自動クラスタリングの埋め込みのダンプが含まれていると、再現がはるかに簡単になります。自動クラスタリングを使用して実行する TensorFlow プログラムでこれらを生成するには、次のように起動します。

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

バグを報告する際は、前述の /tmp/generated ディレクトリの内容を添付してください。

可能であれば、replay_computation を使用して生成されたプログラムで繰り返し実行することで、バグを単一の XLA プログラムまで切り分けるようにしてください。

関連情報

XLA フロントエンド

XLA プログラムは、TensorFlow とは別に、次の手段でも生成できます。

  • JAX: Python+NumPy プログラムの変換構成ツール
  • Julia: 科学計算のためのプログラミング言語
  • PyTorch: PyTorch フレームワーク
  • Nx: Elixir プログラミング言語向け数値計算ライブラリ

講演

jit_compile=True を使用して、TF を通じて XLA を使用する

XLA の概要