This example is based on Image classification via fine-tuning with EfficientNet. It demonstrates how to train a NASNetMobile model at scale with distributed training, using tensorflow_cloud and Google Cloud Platform.
Import required modules
import tensorflow as tf
tf.version.VERSION
'2.6.0'
! pip install -q tensorflow-cloud
import tensorflow_cloud as tfc
tfc.__version__
import os
import sys
Project Configurations
Set project parameters. For Google Cloud-specific parameters, refer to Google Cloud Project Setup Instructions.
# Set Google Cloud Specific parameters
# TODO: Please set GCP_PROJECT_ID to your own Google Cloud project ID.
GCP_PROJECT_ID = 'YOUR_PROJECT_ID'
# TODO: set GCS_BUCKET to your own Google Cloud Storage (GCS) bucket.
GCS_BUCKET = 'YOUR_GCS_BUCKET_NAME'
# DO NOT CHANGE: Currently only the 'us-central1' region is supported.
REGION = 'us-central1'
# OPTIONAL: You can change the job name to any string.
JOB_NAME = 'nasnet'
# Set the location where training logs and checkpoints will be stored
GCS_BASE_PATH = f'gs://{GCS_BUCKET}/{JOB_NAME}'
TENSORBOARD_LOGS_DIR = os.path.join(GCS_BASE_PATH,"logs")
MODEL_CHECKPOINT_DIR = os.path.join(GCS_BASE_PATH,"checkpoints")
SAVED_MODEL_DIR = os.path.join(GCS_BASE_PATH,"saved_model")
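On the Linux hosts behind Colab and Kaggle, os.path.join produces the forward slashes that gs:// URIs need, but on Windows it would insert backslashes. If this configuration cell might run elsewhere, a minimal sketch using the standard library's posixpath keeps the separator fixed. The helper name gcs_join and the lowercase variable names are hypothetical and only mirror the placeholders above.

```python
import posixpath


def gcs_join(base, *parts):
    """Join path segments with forward slashes on any host OS."""
    return posixpath.join(base, *parts)


# Placeholder values mirroring the configuration cell above.
gcs_bucket = "YOUR_GCS_BUCKET_NAME"
job_name = "nasnet"

gcs_base_path = f"gs://{gcs_bucket}/{job_name}"
logs_dir = gcs_join(gcs_base_path, "logs")
checkpoint_dir = gcs_join(gcs_base_path, "checkpoints")
saved_model_dir = gcs_join(gcs_base_path, "saved_model")

print(logs_dir)
```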
Authenticating the notebook to use your Google Cloud Project
For Kaggle Notebooks click on "Add-ons"->"Google Cloud SDK" before running the cell below.
# Using tfc.remote() to ensure this code only runs in the notebook
if not tfc.remote():
    # Authentication for Kaggle Notebooks
    if "kaggle_secrets" in sys.modules:
        from kaggle_secrets import UserSecretsClient
        UserSecretsClient().set_gcloud_credentials(project=GCP_PROJECT_ID)
    # Authentication for Colab Notebooks
    if "google.colab" in sys.modules:
        from google.colab import auth
        auth.authenticate_user()
        os.environ["GOOGLE_CLOUD_PROJECT"] = GCP_PROJECT_ID
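The authentication cell decides which branch to run by checking whether a platform-specific module is already loaded in sys.modules. A minimal sketch of that detection pattern on its own is shown below. The function name detect_notebook_env and its return labels are hypothetical, and unlike the cell above it returns on the first match rather than checking both platforms independently.

```python
import sys


def detect_notebook_env(loaded_modules=None):
    """Guess the hosting notebook platform from already-imported modules.

    loaded_modules defaults to sys.modules; passing a dict makes the
    function easy to exercise without being on either platform.
    """
    modules = sys.modules if loaded_modules is None else loaded_modules
    if "kaggle_secrets" in modules:
        return "kaggle"
    if "google.colab" in modules:
        return "colab"
    return "other"


# Simulated module tables standing in for each platform.
print(detect_notebook_env({"kaggle_secrets": object()}))
print(detect_notebook_env({"google.colab": object()}))
print(detect_notebook_env({}))
```

Checking sys.modules avoids importing a package that does not exist on the current platform, which would otherwise raise ImportError.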
Load and prepare data
Read the raw data and split it into training and test datasets.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
# Setting input specific parameters
# The model expects input of dimension (INPUT_IMG_SIZE, INPUT_IMG_SIZE, 3)
INPUT_IMG_SIZE = 32
NUM_CLASSES = 10
Add preprocessing layers for image augmentation.
from tensorflow.keras.layers.experimental import preprocessing
from tensorflow.keras.models import Sequential
img_augmentation = Sequential(
    [
        # Resizing input to better match ImageNet size
        preprocessing.Resizing(256, 256),
        preprocessing.RandomRotation(factor=0.15),
        preprocessing.RandomFlip(),
        preprocessing.RandomContrast(factor=0.1),
    ],
    name="img_augmentation",
)
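For reference, Keras documents the single-float factor of RandomRotation as a fraction of a full turn (2π), applied symmetrically, so factor=0.15 allows rotations of up to 54 degrees in either direction. The arithmetic is sketched below in plain Python. The helper name rotation_range_degrees is hypothetical and not part of Keras.

```python
def rotation_range_degrees(factor):
    """Return the (min, max) rotation in degrees for a single-float factor.

    Assumes the Keras convention that a single float is a fraction of a
    full turn, used as both the lower and upper bound.
    """
    max_degrees = factor * 360.0
    return -max_degrees, max_degrees


low, high = rotation_range_degrees(0.15)
print(f"rotation range: {low:.1f} to {high:.1f} degrees")
```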
Load the model and prepare for training
We will load a pretrained NASNetMobile model (with ImageNet weights) and unfreeze a few layers for fine-tuning so that the model better matches the dataset.
from tensorflow.keras import layers
def build_model(num_classes, input_image_size):
    inputs = layers.Input(shape=(input_image_size, input_image_size, 3))
    x = img_augmentation(inputs)
    model = tf.keras.applications.NASNetMobile(
        input_shape=None,
        include_top=False,
        weights="imagenet",
        input_tensor=x,
        pooling=None,
        classes=num_classes,
    )
    # Freeze the pretrained weights
    model.trainable = False
    # We unfreeze the top 20 layers while leaving BatchNorm layers frozen
    for layer in model.layers[-20:]:
        if not isinstance(layer, layers.BatchNormalization):
            layer.trainable = True
    # Rebuild top
    x = layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
    x = layers.BatchNormalization()(x)
    x = layers.Dense(128, activation="relu")(x)
    x = layers.Dense(64, activation="relu")(x)
    outputs = layers.Dense(num_classes, activation="softmax", name="pred")(x)
    # Compile
    model = tf.keras.Model(inputs, outputs, name="NASNetMobile")
    optimizer = tf.keras.optimizers.Adam(learning_rate=3e-4)
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"]
    )
    return model
model = build_model(NUM_CLASSES, INPUT_IMG_SIZE)
if tfc.remote():
    # Configure TensorBoard logs, checkpointing, and early stopping
    callbacks = [
        tf.keras.callbacks.TensorBoard(log_dir=TENSORBOARD_LOGS_DIR),
        tf.keras.callbacks.ModelCheckpoint(
            MODEL_CHECKPOINT_DIR,
            save_best_only=True),
        tf.keras.callbacks.EarlyStopping(
            monitor='loss',
            min_delta=0.001,
            patience=3)]
    model.fit(x=x_train, y=y_train, epochs=100,
              validation_split=0.2, callbacks=callbacks)
    model.save(SAVED_MODEL_DIR)
else:
    # Run the training for 1 epoch on a small subset of the data to validate the setup
    model.fit(x=x_train[:100], y=y_train[:100], validation_split=0.2, epochs=1)
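Both fit calls pass validation_split=0.2. Keras holds out that fraction from the end of the arrays, before any shuffling, so the training and validation sizes follow directly from the sample count. The sketch below works through the arithmetic for the full CIFAR-10 training set (50,000 images) and for the 100-sample local check. The helper name validation_split_sizes is hypothetical, and rounding the boundary down is an assumption about how the split index is computed.

```python
import math


def validation_split_sizes(num_samples, validation_split):
    """Return (train_size, val_size) when the last fraction is held out."""
    split_at = int(math.floor(num_samples * (1.0 - validation_split)))
    return split_at, num_samples - split_at


# Remote run: the full CIFAR-10 training set.
remote_train, remote_val = validation_split_sizes(50_000, 0.2)
# Local run: the 100-sample subset used to validate the setup.
local_train, local_val = validation_split_sizes(100, 0.2)

print(remote_train, remote_val)
print(local_train, local_val)
```

Because the held-out samples come from the tail of the arrays, data sorted by class would give a skewed validation set; shuffling beforehand avoids that.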
Start the remote training
This step prepares the code in this notebook for remote execution and starts a distributed training job on Google Cloud Platform. Once the job is submitted, you can go to the next step to monitor the job's progress via TensorBoard.
# If you are using a custom image, you can install modules via a
# requirements.txt file.
with open('requirements.txt', 'w') as f:
    f.write('tensorflow-cloud\n')
# Optional: Some recommended base images. If you provide none, the system
# will choose one for you.
TF_GPU_IMAGE = "tensorflow/tensorflow:latest-gpu"
TF_CPU_IMAGE = "tensorflow/tensorflow:latest"
# Submit a distributed training job using GPUs.
tfc.run(
    distribution_strategy='auto',
    requirements_txt='requirements.txt',
    docker_config=tfc.DockerConfig(
        parent_image=TF_GPU_IMAGE,
        image_build_bucket=GCS_BUCKET
    ),
    chief_config=tfc.COMMON_MACHINE_CONFIGS['K80_1X'],
    worker_config=tfc.COMMON_MACHINE_CONFIGS['K80_1X'],
    worker_count=3,
    job_labels={'job': JOB_NAME}
)
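With distribution_strategy='auto', tensorflow_cloud picks a tf.distribute strategy from the machine configuration. The sketch below mirrors the selection rules as this editor understands them from the tensorflow_cloud documentation; it is an assumption written in plain Python, not the library's own code. The helper names pick_strategy_name and total_machines are hypothetical.

```python
def pick_strategy_name(chief_gpus, worker_count, uses_tpu=False):
    """Hypothetical helper: name the strategy 'auto' is assumed to choose."""
    if worker_count > 0:
        # More than one machine: multi-worker training, or TPU if requested.
        return "TPUStrategy" if uses_tpu else "MultiWorkerMirroredStrategy"
    if chief_gpus > 1:
        # One machine with several GPUs.
        return "MirroredStrategy"
    # One machine with at most one accelerator.
    return "OneDeviceStrategy"


def total_machines(worker_count):
    """The chief counts as one machine in addition to the workers."""
    return 1 + worker_count


# The job above: a chief with one K80 plus three workers with one K80 each.
strategy = pick_strategy_name(chief_gpus=1, worker_count=3)
machines = total_machines(worker_count=3)
print(strategy, machines)
```

Under these assumptions, the configuration above runs on four single-GPU machines, which is why the job is described as distributed even though each machine has only one K80.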
Training Results
Reconnect your Colab instance
Most remote training jobs are long-running. If you are using Colab, it may time out before the training results are available. In that case, rerun the following sections in order to reconnect and configure your Colab instance to access the training results:
- Import required modules
- Project Configurations
- Authenticating the notebook to use your Google Cloud Project
Load TensorBoard
While the training is in progress, you can use TensorBoard to view the results. Note that results appear only after training has started, which may take a few minutes.
%load_ext tensorboard
%tensorboard --logdir $TENSORBOARD_LOGS_DIR
Load your trained model
trained_model = tf.keras.models.load_model(SAVED_MODEL_DIR)
trained_model.summary()