Distribution that shards model variables.
tf.keras.distribution.ModelParallel(
    device_mesh, layout_map, batch_dim_name=None
)
Compared to DataParallel, which replicates the variables across all devices, ModelParallel allows you to shard the variables in addition to the input data.
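For contrast, a minimal DataParallel setup only needs the list of devices; every variable is replicated on each device. This is a hedged sketch, not part of the original example:

from tensorflow.keras.distribution import DataParallel, list_devices, set_distribution

# Replicate all variables across every available device; only the input
# batches are split across devices.
data_parallel = DataParallel(devices=list_devices())
set_distribution(data_parallel)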
To construct a ModelParallel distribution, you need to provide a DeviceMesh and a LayoutMap. The DeviceMesh contains the physical device information, and its axis names are used to map the variable and data layouts. The LayoutMap contains the mapping between variable paths and their corresponding TensorLayout.
Example:
from tensorflow.keras.distribution import (
    DeviceMesh, LayoutMap, ModelParallel, list_devices, set_distribution)

devices = list_devices()  # Assume there are 8 devices.

# Create a mesh with 2 devices for data parallelism and 4 devices for
# model parallelism.
device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),
                         devices=devices)

# Create a layout map that shards the `Dense` layer and `Conv2D`
# layer variables on the last dimension.
# Based on the `device_mesh`, this means the variables
# will be split across 4 devices. Any other variable that doesn't
# match any key in the layout map will be fully replicated.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)

distribution = ModelParallel(device_mesh=device_mesh,
                             layout_map=layout_map,
                             batch_dim_name='batch')

# Set the global distribution, or use `with distribution.scope():`.
set_distribution(distribution)

# `model_creation()` and `data` are user-defined placeholders.
model = model_creation()
model.compile()
model.fit(data)
You can quickly update the device mesh shape to change the sharding factor of the variables, e.g.:

# With only the shape of the device mesh changed, the variables will be
# sharded across 8 devices instead of 4, which further reduces the memory
# footprint of the variables on each device.
device_mesh = DeviceMesh(shape=(1, 8), axis_names=('batch', 'model'),
                         devices=devices)
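The layout-map keys themselves do not change; only the mesh does. A hedged sketch of re-running the setup against the reshaped mesh:

# Rebuild the layout map and distribution against the new (1, 8) mesh,
# using the same regex keys as before.
layout_map = LayoutMap(device_mesh)
layout_map['dense.*kernel'] = (None, 'model')
layout_map['dense.*bias'] = ('model',)
layout_map['conv2d.*kernel'] = (None, None, None, 'model')
layout_map['conv2d.*bias'] = ('model',)
distribution = ModelParallel(device_mesh=device_mesh,
                             layout_map=layout_map,
                             batch_dim_name='batch')
set_distribution(distribution)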
To figure out a proper layout mapping rule for all the model variables, you can first list out all the model variable paths, which will be used as the keys that map the variables to a TensorLayout, e.g.:

model = create_model()
for v in model.variables:
    print(v.path)
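It can also help to print each variable's shape next to its path, so you know how many axes the matching TensorLayout entry needs. A minimal sketch, reusing the model from above:

for v in model.variables:
    # A Dense kernel has 2 axes, a Conv2D kernel has 4, biases have 1.
    print(v.path, v.shape)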
Attributes | |
---|---|
device_mesh | The DeviceMesh that was passed to the distribution.
Methods
distribute_dataset
distribute_dataset(
    dataset
)

Create a distributed dataset instance from the original user dataset.

Args | |
---|---|
dataset | The original global dataset instance. Only tf.data.Dataset is supported at the moment.

Returns | |
---|---|
A sharded tf.data.Dataset instance, which will produce data for the current local worker/process. |
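A minimal usage sketch, assuming the distribution and model from the example above and a small synthetic tf.data.Dataset:

import numpy as np
import tensorflow as tf

x = np.random.uniform(size=(64, 8)).astype('float32')
y = np.random.uniform(size=(64, 1)).astype('float32')
global_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)

# Each local worker/process receives only its shard of the global dataset.
local_ds = distribution.distribute_dataset(global_ds)
model.fit(local_ds)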
get_data_layout
get_data_layout(
    data_shape
)

Retrieve the TensorLayout for the input data.

Args | |
---|---|
data_shape | Shape of the input data, in list or tuple format.

Returns | |
---|---|
The TensorLayout for the data, which can be used by backend.distribute_value() to redistribute the input data. |
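A hedged sketch, assuming the (2, 4) mesh and batch_dim_name='batch' from the example above; the first (batch) dimension of the data is mapped to the 'batch' mesh axis:

# For a global batch of shape (32, 8), the returned layout shards the
# batch dimension across the 'batch' axis and replicates the rest.
data_layout = distribution.get_data_layout((32, 8))
print(data_layout.axes)  # e.g. ('batch', None)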
get_tensor_layout
get_tensor_layout(
    path
)

Retrieve the TensorLayout for the intermediate tensor.

Args | |
---|---|
path | A string path for the corresponding tensor.

Returns | |
---|---|
The TensorLayout for the intermediate tensor, which can be used by backend.relayout() to reshard the tensor. Could also return None. |
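A hedged sketch with a hypothetical tensor path; as documented above, the result may be None when no layout is registered for that path:

# 'dense_1/output' is a made-up path purely for illustration.
tensor_layout = distribution.get_tensor_layout('dense_1/output')
if tensor_layout is not None:
    print(tensor_layout.axes)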
get_variable_layout
get_variable_layout(
    variable
)

Retrieve the TensorLayout for the variable.

Args | |
---|---|
variable | A KerasVariable instance.

Returns | |
---|---|
The TensorLayout for the variable, which can be used by backend.distribute_value() to redistribute the variable. |
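A minimal sketch, assuming the distribution and model from the example above; variables that match no key in the layout map come back fully replicated:

for v in model.variables:
    layout = distribution.get_variable_layout(v)
    # Sharded variables show the 'model' axis; replicated ones show all None.
    print(v.path, layout.axes)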
scope
@contextlib.contextmanager
scope()

Context manager to make the Distribution current.
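A minimal sketch of using scope() instead of set_distribution(), as mentioned in the example above, so the distribution only applies inside the with block:

# Variables created inside the block are laid out according to `distribution`.
with distribution.scope():
    model = model_creation()
    model.compile()
    model.fit(data)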