XLA: Optimización del compilador para el aprendizaje automático

XLA (Accelerated Linear Algebra) es un compilador específico del área para álgebra lineal que puede acelerar los modelos de TensorFlow prácticamente sin cambiar el código fuente.

Los resultados son mejoras en la velocidad y el uso de memoria. Por ejemplo, en BERT, el envío de MLPerf con 8 GPU Volta V100 que usa XLA la mejora del rendimiento se multiplicó por 7 y la mejora del tamaño de lotes, por 5.

Introducción

Cuando se ejecuta un programa de TensorFlow, el ejecutor se encarga de todas las operaciones de forma individual. Cada operación de TensorFlow tiene una implementación de kernel de GPU previamente compilada a la que el ejecutor envía datos.

XLA proporciona un modo alternativo de ejecutar modelos: compila el grafo de TensorFlow en una secuencia de kernels de cálculo generados específicamente para el modelo determinado. Debido a que estos kernels son exclusivos del modelo, pueden aprovechar la información específica del modelo para la optimización. Por ejemplo, veamos una optimización que implementa XLA en el contexto de un cálculo simple de TensorFlow:

def model_fn(x, y, z):
  return tf.reduce_sum(x + y * z)

Si se ejecuta sin XLA, el grafo inicia tres kernels: uno para la multiplicación, uno para la suma y otro para la resta. Sin embargo, XLA puede optimizar el grafo para que calcule el resultado en un solo proceso de kernel. Para ello, "fusiona" la suma, la multiplicación y la resta en un solo kernel de GPU. Además, esta operación combinada no escribe los valores intermedios que generan y*z y x+y*z en la memoria. En cambio, "transmite" los resultados de estos cálculos intermedios directamente a los usuarios y los mantiene por completo en los registros de la GPU. La fusión es la optimización más importante de XLA. Por lo general, el ancho de banda de la memoria es el recurso más escaso en los aceleradores de hardware, por lo tanto, quitar las operaciones de memoria es una de las mejores formas de mejorar el rendimiento.

Habilita XLA para modelos de TensorFlow

Compilación explícita con tf.function(jit_compile=True)

La API de compilación explícita ofrece un control más detallado para elegir qué funciones se deben compilar. Por ejemplo, la siguiente función de TensorFlow, que realiza el entrenamiento de MNIST, se compila con XLA:

@tf.function(jit_compile=True)
def train_mnist(images, labels):
    images, labels = cast(images, labels)

    with tf.GradientTape() as tape:
      predicted_labels = layer(images)
      loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=predicted_labels, labels=labels
      ))
    layer_variables = layer.trainable_variables
    grads = tape.gradient(loss, layer_variables)
    optimizer.apply_gradients(zip(grads, layer_variables))

La API de jit_compile tiene una semántica se debe compilar: se compila toda la función con XLA o se genera una excepción errors.InvalidArgumentError. Actualmente, XLA no puede compilar funciones en las que las dimensiones no se puedan inferir, es decir, cuando no es posible inferir las dimensiones de todos los tensores sin ejecutar todo el cálculo. Por ejemplo, la siguiente función no se compilará:

@tf.function
def not_compilable(x):
  return tf.unique(x)

Pero las formas pueden variar en las ejecuciones:

@tf.function(jit_compile=True)
def recompiled_on_launch(a, b):
  return a + b

recompiled_on_launch(tf.ones([1, 10]), tf.ones([1, 10]))
recompiled_on_launch(tf.ones([1, 100]), tf.ones([1, 100]))

Consulta el instructivo de Colab para ver un ejemplo de uso más detallado y un video instructivo sobre el jit_compile=True uso.

Agrupamiento en clústeres automático

Una manera simple de usar XLA en modelos de TensorFlow sin realizar ningún cambio es habilitar el agrupamiento en clústeres automático, que encuentra automáticamente los clústeres (subgrafos conectados) dentro de las funciones de TensorFlow que se pueden compilar y ejecutar mediante XLA. El agrupamiento en clústeres automático en la GPU se puede habilitar mediante la configuración de la variable de entorno TF_XLA_FLAGS:

$ TF_XLA_FLAGS=--tf_xla_auto_jit=2 path/to/your/tf/program

Actualmente, el agrupamiento en clústeres automático está optimizado para cargas de trabajo de GPU, pero también se puede habilitar en la CPU si se usa la marca --tf_xla_cpu_global_jit:

$ TF_XLA_FLAGS="--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit" path/to/your/program

Para obtener un ejemplo de uso detallado, consulta el instructivo de agrupamiento en clústeres automático de Colab.

Compilación AOT (anticipada) para CPU con tfcompile

También puedes usar una herramienta tfcompile independiente, que convierte el grafo de TensorFlow en código ejecutable (solo para CPU x86-64).

Cómo inspeccionar los programas compilados

XLA proporciona recursos de introspección que te permiten inspeccionar los programas generados. Para volcar los programas generados, usa la variable de entorno XLA_FLAGS:

$ XLA_FLAGS="--xla_dump_to=/tmp/generated" TF_XLA_FLAGS="--tf_xla_auto_jit=2" my/tensorflow/program

Después de realizar el volcado, podrás encontrar los siguientes archivos en /tmp/generated:

  • module_XXXX.*_optimizations.txt Se generaron programas de XLA, uno por cada clúster compilado. Es muy útil adjuntarlos cuando se envían los informes de errores de XLA.

  • module_XXXX.ir-*.ll Se generaron archivos en la representación intermedia de LLVM, con funciones intrínsecas de NVPTX.

  • module_XXXX.ptx Se generaron archivos PTX.

También puedes volcar el grafo si visualizas la incorporación de clústeres de XLA dentro del grafo de TensorFlow con lo siguiente:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated TF_XLA_FLAGS="--tf_xla_clustering_debug"

Informes de errores reproducibles

Un informe de errores es mucho más fácil de reproducir si incluye volcados para los programas XLA generados y la incorporación utilizada para el agrupamiento en clústeres automático. Para generarlos en un programa de TensorFlow que se ejecute con agrupamiento en clústeres automático, inicia lo siguiente:

$ TF_DUMP_GRAPH_PREFIX=/tmp/generated \
  TF_XLA_FLAGS="--tf_xla_clustering_debug --tf_xla_auto_jit=2" \
  XLA_FLAGS="--xla_dump_hlo_as_text --xla_dump_to=/tmp/generated" \
    my/tensorflow/program"

Cuando envíes errores, adjunta el contenido del directorio de /tmp/generated (que se muestra más arriba).

Si es posible, intenta aislar un error en un solo programa XLA con replay_computation y ejecútalo de forma iterativa en los programas generados.

Lecturas adicionales

Frontends de XLA

Además de hacerlo con TensorFlow, los programas XLA se pueden generar también con:

  • JAX: Transformaciones de programas Python+NumPy que se pueden componer
  • Julia: El lenguaje Julia para el procesamiento científico
  • PyTorch: Framework de PyTorch
  • Nx: Biblioteca de procesamiento numérico para el lenguaje de programación Elixir

Charlas

Uso de XLA desde TF mediante jit_compile=True

Descripción general de XLA