tf.keras.distribution.ModelParallel

Distribution that shards model variables.

Compared to DataParallel, which replicates the variables across all devices, ModelParallel allows you to shard model variables in addition to the input data.

To construct a ModelParallel distribution, you need to provide a DeviceMesh and a LayoutMap.

  1. DeviceMesh contains the physical device information. The axis names in the mesh will be used to map the variable and data layouts.
  2. LayoutMap contains the mapping between variable paths and their corresponding TensorLayout.

Example:

```python
devices = list_devices()    # Assume there are 8 devices.

# Create a mesh with 2 devices for data parallelism and 4 devices for
# model parallelism.
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
                         devices=devices)
# Create a layout map that shards the `Dense` and `Conv2D` layer
# variables on their last dimension.
# Based on the `device_mesh`, this means the variables
# will be split across 4 devices. Any other variable that doesn't
# match any key in the layout map will be fully replicated.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)

distribution = ModelParallel(device_mesh=device_mesh,
                             layout_map=layout_map,
                             batch_dim_name='batch')
# Set the global distribution, or use `with distribution.scope():` instead.
set_distribution(distribution)

model = model_creation()
model.compile()
model.fit(data)
```

You can quickly update the device mesh shape to change the sharding factor of the variables. For example:

```python
# With only this shape change for the device mesh, the variables will be
# sharded across 8 devices instead of 4, which further reduces the memory
# footprint of the variables on each of the devices.
device_mesh = DeviceMesh(shape=(1, 8), axis_names=('batch', 'model'),
                         devices=devices)
```
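
Note that the LayoutMap and the distribution are constructed against a specific mesh, so after reshaping the mesh they are rebuilt on it. A minimal sketch, reusing the layout keys from the example above:

```python
# Rebuild the layout map and the distribution on the reshaped mesh; the same
# regex keys now shard the matching variables across all 8 'model' devices.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
distribution = ModelParallel(device_mesh=device_mesh,
                             layout_map=layout_map,
                             batch_dim_name='batch')
set_distribution(distribution)
```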

To figure out a proper layout mapping rule for all the model variables, you can first list all the model variable paths, which will be used as the keys that map variables to a TensorLayout. For example:

```python
model = create_model()
for v in model.variables:
    print(v.path)
```

Args
device_mesh DeviceMesh instance for the physical devices and their logical mapping.
layout_map LayoutMap instance which maps variable paths to their corresponding TensorLayout. The axis names of the TensorLayouts should match the axis names in the device_mesh, or an exception will be raised.
batch_dim_name optional string, the axis name in the device_mesh that will be used to distribute data. If unspecified, the first axis from the device_mesh will be used.

Attributes
device_mesh The DeviceMesh instance used by this distribution.

Methods

distribute_dataset

Create a distributed dataset instance from the original user dataset.

Args
dataset the original global dataset instance. Only tf.data.Dataset is supported at the moment.

Returns
a sharded tf.data.Dataset instance, which will produce data for the current local worker/process.
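
As an illustration, a minimal sketch of feeding a global tf.data.Dataset through the distribution; the data is a placeholder, and `distribution` and `model` are carried over from the example above:

```python
import numpy as np
import tensorflow as tf

# Placeholder global dataset for illustration purposes only.
features = np.random.uniform(size=(64, 28, 28, 1)).astype('float32')
labels = np.random.randint(0, 10, size=(64,))
global_dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(8)

# Each local worker/process receives only its shard of the global batches.
per_worker_dataset = distribution.distribute_dataset(global_dataset)
model.fit(per_worker_dataset)
```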

get_data_layout

Retrieve the TensorLayout for the input data.

Args
data_shape the shape of the input data, as a list or tuple.

Returns
The TensorLayout for the data, which can be used by backend.distribute_value() to redistribute the input data.
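
For illustration, a small sketch assuming the mesh and batch_dim_name='batch' from the example above, and assuming TensorLayout exposes its sharding spec through an axes attribute:

```python
# Only the batch dimension is mapped to the 'batch' mesh axis; the
# remaining dimensions are replicated (None).
data_layout = distribution.get_data_layout([16, 28, 28, 1])
print(data_layout.axes)  # e.g. ('batch', None, None, None)
```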

get_tensor_layout

Retrieve the TensorLayout for the intermediate tensor.

Args
path a string path for the corresponding tensor.

Returns
The TensorLayout for the intermediate tensor, which can be used by backend.relayout() to reshard the tensor. May also be None.
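
A hedged sketch of how this might be used; the tensor path below is hypothetical, and a layout only comes back when the layout map contains a matching entry:

```python
# 'dense_1/output' is a made-up intermediate tensor path for illustration.
tensor_layout = distribution.get_tensor_layout('dense_1/output')
if tensor_layout is None:
    print('No layout registered for this path; keeping the default layout.')
```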

get_variable_layout

Retrieve the TensorLayout for the variable.

Args
variable A KerasVariable instance.

Returns
The TensorLayout for the variable, which can be used by backend.distribute_value() to redistribute the variable.
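
Assuming the layout_map from the example above, variables whose paths match 'dense.*kernel' resolve to the (None, 'model') layout, while unmatched variables come back fully replicated. One way to inspect the resolved layouts:

```python
# Print the layout assigned to every variable in the model.
for v in model.variables:
    print(v.path, distribution.get_variable_layout(v))
```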

scope

Context manager to make the Distribution current.
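
This is the alternative to calling set_distribution() shown in the example above; everything created inside the with block uses this distribution. A minimal sketch, reusing the placeholder model_creation() and data from that example:

```python
# Variables created inside the scope are laid out according to the
# distribution's layout map.
with distribution.scope():
    model = model_creation()
    model.compile()
    model.fit(data)
```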