This tutorial trains a TensorFlow model to classify the CIFAR-10 dataset and compiles it with XLA.
Load and normalize the dataset using the Keras API:
import tensorflow as tf
# Check that GPU is available: cf. https://colab.research.google.com/notebooks/gpu.ipynb
assert(tf.test.gpu_device_name())
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(False) # Start with XLA disabled.
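If you prefer an explicit device query over the assert above, TF 2.x also lets you list the GPUs it can see (this alternative check is an addition, not part of the original notebook):

# Alternative GPU check: list the physical GPUs TensorFlow detects.
gpus = tf.config.list_physical_devices('GPU')
print('GPUs available:', gpus)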
def load_data():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
  x_train = x_train.astype('float32') / 256
  x_test = x_test.astype('float32') / 256

  # Convert class vectors to binary class matrices.
  y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
  y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)
  return ((x_train, y_train), (x_test, y_test))
(x_train, y_train), (x_test, y_test) = load_data()
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170500096/170498071 [==============================] - 2s 0us/step
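As a quick sanity check (an addition, not part of the original notebook), you can confirm the expected CIFAR-10 shapes after loading:

# CIFAR-10 is 50,000 train / 10,000 test images of 32x32 RGB pixels,
# with one-hot labels over 10 classes after to_categorical.
print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 10)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 10)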
Define the model, adapted from the Keras CIFAR-10 example:
def generate_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(32, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Conv2D(64, (3, 3), padding='same'),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(64, (3, 3)),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(0.25),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Activation('softmax')
  ])
model = generate_model()
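To inspect the architecture before training, you can print a layer-by-layer summary (a standard Keras call, not in the original notebook):

model.summary()  # Prints each layer's output shape and parameter count.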
Train the model using the RMSprop optimizer:
def compile_model(model):
  # Use `learning_rate`; the `lr` argument is deprecated.
  opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001, decay=1e-6)
  model.compile(loss='categorical_crossentropy',
                optimizer=opt,
                metrics=['accuracy'])
  return model
model = compile_model(model)
def train_model(model, x_train, y_train, x_test, y_test, epochs=25):
  model.fit(x_train, y_train, batch_size=256, epochs=epochs,
            validation_data=(x_test, y_test), shuffle=True)

def warmup(model, x_train, y_train, x_test, y_test):
  # Warm up the JIT; we don't want to measure the compilation time.
  initial_weights = model.get_weights()
  train_model(model, x_train, y_train, x_test, y_test, epochs=1)
  model.set_weights(initial_weights)
warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
Epoch 1/25
196/196 [==============================] - 2s 9ms/step - loss: 2.1007 - accuracy: 0.2191 - val_loss: 1.8853 - val_accuracy: 0.3260
Epoch 2/25
196/196 [==============================] - 2s 8ms/step - loss: 1.7937 - accuracy: 0.3516 - val_loss: 1.6568 - val_accuracy: 0.4097
...
Epoch 25/25
196/196 [==============================] - 1s 8ms/step - loss: 1.0236 - accuracy: 0.6436 - val_loss: 0.9797 - val_accuracy: 0.6587
CPU times: user 47 s, sys: 5.72 s, total: 52.7 s
Wall time: 39.4 s
313/313 [==============================] - 1s 2ms/step - loss: 0.9797 - accuracy: 0.6587
Test loss: 0.9796982407569885
Test accuracy: 0.6586999893188477
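The %time magic only works in IPython/Colab. If you adapt this tutorial to a plain Python script, a minimal timing sketch (an addition, not part of the original notebook) is:

import time

start = time.time()
train_model(model, x_train, y_train, x_test, y_test)
print('Wall time: %.1f s' % (time.time() - start))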
Now let's train the model again, this time with the XLA compiler. To enable the compiler in the middle of the application, we need to reset the Keras session.
# We need to clear the session to enable JIT in the middle of the program.
tf.keras.backend.clear_session()
tf.config.optimizer.set_jit(True) # Enable XLA.
model = compile_model(generate_model())
(x_train, y_train), (x_test, y_test) = load_data()
warmup(model, x_train, y_train, x_test, y_test)
%time train_model(model, x_train, y_train, x_test, y_test)
Epoch 1/25
196/196 [==============================] - 4s 18ms/step - loss: 2.1271 - accuracy: 0.2144 - val_loss: 1.8623 - val_accuracy: 0.3491
Epoch 2/25
196/196 [==============================] - 1s 7ms/step - loss: 1.8081 - accuracy: 0.3496 - val_loss: 1.6823 - val_accuracy: 0.4058
...
Epoch 25/25
196/196 [==============================] - 1s 7ms/step - loss: 1.0278 - accuracy: 0.6406 - val_loss: 0.9785 - val_accuracy: 0.6616
CPU times: user 39.4 s, sys: 5.3 s, total: 44.7 s
Wall time: 39.9 s
On a machine with a Titan V GPU and an Intel Xeon E5-2690 CPU, the speedup is about 1.17x.
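Besides the global tf.config.optimizer.set_jit flag used above, newer TensorFlow versions also let you request XLA compilation for a single function through tf.function's jit_compile argument. A minimal sketch (an addition, not part of the original notebook):

@tf.function(jit_compile=True)
def dense_layer(x, w, b):
  # The whole function body is compiled into a single XLA computation.
  return tf.nn.relu(tf.matmul(x, w) + b)

This scopes compilation to just the functions you annotate, rather than enabling XLA for the entire session.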