tf.experimental.dtensor.Mesh

Represents a Mesh configuration over a certain list of Mesh Dimensions.

A mesh consists of named dimensions with sizes, which describe how a set of devices is arranged. Defining tensor layouts in terms of mesh dimensions allows us to efficiently determine the communication required when computing an operation with tensors of different layouts.

A mesh provides information not only about the placement of the tensors but also the topology of the underlying devices. For example, we can group 8 TPUs as a 1-D array for data parallelism or a 2x4 grid for (2-way) data parallelism and (4-way) model parallelism.
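The following pure-Python sketch (it does not call DTensor) lays out the 8 device IDs both ways. It assumes row-major placement of device IDs, which matches the strides formula documented for this class. The dimension names `batch` and `model` and the helper `mesh_groups` are illustrative, not part of the API.

```python
def mesh_groups(rows, cols):
    """Place device IDs 0..rows*cols-1 row-major on a rows x cols grid.

    Returns the grid plus the device groups that vary along each dimension.
    """
    grid = [[r * cols + c for c in range(cols)] for r in range(rows)]
    # A column holds devices that differ only in the first ("batch") index:
    # these are the data-parallel replicas of one model shard.
    batch_groups = [[grid[r][c] for r in range(rows)] for c in range(cols)]
    # A row holds devices that differ only in the second ("model") index:
    # together they hold all shards of one model replica.
    model_groups = [list(row) for row in grid]
    return grid, batch_groups, model_groups


# 8 devices as a 1-D array: every device is one data-parallel replica.
print(list(range(8)))   # [0, 1, 2, 3, 4, 5, 6, 7]

# The same 8 devices as a 2x4 ("batch", "model") grid.
grid, batch_groups, model_groups = mesh_groups(2, 4)
print(grid)             # [[0, 1, 2, 3], [4, 5, 6, 7]]
print(batch_groups)     # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(model_groups)     # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

In the 2x4 layout, gradient synchronization for data parallelism happens within each batch group, while the devices in a model group cooperate on a single replica.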

Args

dim_names: A list of strings indicating dimension names.
global_device_ids: An ndarray of global device IDs, used to compose the DeviceSpecs that describe the mesh. The shape of this array determines the size of each mesh dimension. Values in this array should increment sequentially from 0. This argument is the same for every DTensor client.
local_device_ids: A list of local device IDs equal to a subset of values in global_device_ids. They indicate the position of local devices in the global mesh. The local_device_ids of different DTensor clients must be distinct (non-overlapping), and together the local_device_ids from all DTensor clients must cover every element in global_device_ids.
local_devices: The list of devices hosted locally. The elements correspond 1:1 to those of local_device_ids.
mesh_name: The name of the mesh. Currently this is rarely used; it mostly indicates whether the mesh is CPU-, GPU-, or TPU-based.
global_devices (optional): The list of global devices. Set when multiple device meshes are in use.
use_xla_spmd (optional): Boolean; when True, XLA SPMD is used instead of DTensor SPMD.
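The constraints on these arguments can be summarized as a small checker. This is a hypothetical pure-Python helper, `mesh_arg_problems`, not part of DTensor. It assumes the global IDs are passed flattened in row-major order together with the array shape, reads "increment sequentially from 0" as "the flattened IDs are exactly 0, 1, ..., N-1", and uses made-up device strings.

```python
from math import prod


def mesh_arg_problems(dim_names, shape, global_device_ids, clients):
    """Return a list of violated constraints (empty if the arguments are valid).

    shape: the shape of the global_device_ids ndarray; the IDs themselves are
    passed flattened in row-major order.
    clients: one (local_device_ids, local_devices) pair per DTensor client.
    """
    problems = []
    if len(dim_names) != len(shape):
        problems.append("need one dimension name per mesh dimension")
    if len(global_device_ids) != prod(shape):
        problems.append("number of global device IDs does not match the shape")
    if list(global_device_ids) != list(range(len(global_device_ids))):
        problems.append("global device IDs must increment sequentially from 0")
    covered = set()
    for local_ids, local_devices in clients:
        if len(local_ids) != len(local_devices):
            problems.append("local_devices must correspond 1:1 to local_device_ids")
        if not set(local_ids) <= set(global_device_ids):
            problems.append("local IDs must be a subset of the global IDs")
        if covered & set(local_ids):
            problems.append("local_device_ids overlap between clients")
        covered |= set(local_ids)
    if covered != set(global_device_ids):
        problems.append("clients do not cover every global device ID")
    return problems


# Hypothetical two-client setup for a 2x4 mesh; device strings are made up.
clients = [
    ([0, 1, 2, 3], [f"/job:worker/task:0/device:CPU:{i}" for i in range(4)]),
    ([4, 5, 6, 7], [f"/job:worker/task:1/device:CPU:{i}" for i in range(4)]),
]
print(mesh_arg_problems(["batch", "model"], [2, 4], list(range(8)), clients))  # []
```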

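Attributes

dim_names: The list of dimension names.

name: The name of the mesh.

size: The total number of devices in the mesh.

strides: Returns the strides array for this mesh.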

If the mesh shape is [a, b, c, d], the strides array is [b*c*d, c*d, d, 1]. This array is useful for computing local device offsets from a device ID. Continuing the same example, the mesh coordinates of a device are:

[(device_id // (b*c*d)) % a,
 (device_id // (c*d))   % b,
 (device_id // d)       % c,
 (device_id)            % d]

This is the same as (device_id // strides) % shape, evaluated elementwise over the strides array and the mesh shape.
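As a plain-Python sketch of the formula above (not DTensor's implementation; `mesh_strides` and `device_coords` are hypothetical helper names), the strides and device coordinates can be computed as follows:

```python
from math import prod


def mesh_strides(shape):
    """Row-major strides: for shape [a, b, c, d] this is [b*c*d, c*d, d, 1]."""
    # prod of an empty slice is 1, which gives the last dimension stride 1.
    return [prod(shape[i + 1:]) for i in range(len(shape))]


def device_coords(device_id, shape):
    """Mesh coordinates of a device: (device_id // strides) % shape, elementwise."""
    return [(device_id // s) % n for s, n in zip(mesh_strides(shape), shape)]


shape = [2, 3, 4]
print(mesh_strides(shape))       # [12, 4, 1]
print(device_coords(22, shape))  # [1, 2, 2]

# The strides also invert the mapping: sum(coord * stride) recovers the ID.
coords = device_coords(22, shape)
print(sum(c * s for c, s in zip(coords, mesh_strides(shape))))  # 22
```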

Methods

as_proto

Returns the mesh as a protocol buffer.

contains_dim

contains_dim(self: tensorflow.python._pywrap_dtensor_device.Mesh, dim_name: str) -> bool

Returns True if a Mesh contains the given dimension name.

coords

Converts the device index into a tensor of mesh coordinates.

device_type

device_type(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> str

Returns the device_type of a Mesh.

dim_size

Returns the size of a dimension.
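To illustrate how contains_dim and dim_size relate to the dimension names and the mesh shape, here is a hypothetical `ToyMesh` class that models only named dimension sizes. It is not the DTensor class and cannot place tensors; it also mirrors the `in` operator (documented below as `__contains__`) and assumes that the mesh size is the product of its dimension sizes.

```python
class ToyMesh:
    """Hypothetical stand-in modeling only a mesh's dimension names and sizes."""

    def __init__(self, dims):
        # dims: ordered (name, size) pairs, e.g. [("batch", 2), ("model", 4)].
        self.dim_names = [name for name, _ in dims]
        self._sizes = dict(dims)

    def contains_dim(self, dim_name):
        return dim_name in self._sizes

    __contains__ = contains_dim  # lets `"batch" in mesh` work as well

    def dim_size(self, dim_name):
        # Look up one dimension's size by name (KeyError if it is absent).
        return self._sizes[dim_name]

    def shape(self):
        # Sizes in dimension order, matching dim_names.
        return [self._sizes[name] for name in self.dim_names]

    @property
    def size(self):
        # Total number of devices: the product of all dimension sizes.
        total = 1
        for n in self.shape():
            total *= n
        return total


mesh = ToyMesh([("batch", 2), ("model", 4)])
print(mesh.contains_dim("model"), "x" in mesh)          # True False
print(mesh.dim_size("model"), mesh.shape(), mesh.size)  # 4 [2, 4] 8
```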

from_proto

Constructs a mesh instance from an input proto.

from_string

Constructs a mesh instance from an input string, such as the one returned by to_string.

host_mesh

Returns the 1-1 mapped host mesh.

is_remote

Returns True if a Mesh contains only remote devices.

local_device_ids

Returns a list of local device IDs.

local_device_locations

Returns a list of local device locations.

A device location is a dictionary from dimension names to indices on those dimensions.
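A minimal sketch of what such locations look like, assuming the row-major device numbering described under strides. The helper `local_device_locations` below is a hypothetical stand-alone function that takes the dimension names, mesh shape, and local device IDs explicitly; it is not the DTensor method.

```python
def local_device_locations(dim_names, shape, local_device_ids):
    """Return one {dim_name: index} dict per local device ID (row-major layout)."""
    locations = []
    for device_id in local_device_ids:
        location = {}
        remaining = device_id
        # Peel off the fastest-varying (last) dimension first.
        for name, size in reversed(list(zip(dim_names, shape))):
            remaining, location[name] = divmod(remaining, size)
        # Rebuild the dict so keys follow dim_names order.
        locations.append({name: location[name] for name in dim_names})
    return locations


# Hypothetical second client of a 2x4 mesh that hosts devices 4..7.
for loc in local_device_locations(["batch", "model"], [2, 4], [4, 5, 6, 7]):
    print(loc)
# {'batch': 1, 'model': 0}
# {'batch': 1, 'model': 1}
# {'batch': 1, 'model': 2}
# {'batch': 1, 'model': 3}
```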

local_devices

Returns a list of local device specs represented as strings.

min_global_device_id

Returns the minimum global device ID.

num_local_devices

Returns the number of local devices.

shape

Returns the shape of the mesh.

to_string

to_string(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> str

Returns string representation of Mesh.

unravel_index

Returns a dictionary from device ID to {dim_name: dim_index}.

For example, for a 3x2 mesh with dimension names 'x' and 'y', it returns:

  { 0: {'x': 0, 'y': 0},
    1: {'x': 0, 'y': 1},
    2: {'x': 1, 'y': 0},
    3: {'x': 1, 'y': 1},
    4: {'x': 2, 'y': 0},
    5: {'x': 2, 'y': 1} }
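A pure-Python sketch that reproduces the 3x2 example. `unravel_all` is a hypothetical helper, not DTensor's implementation, and it assumes the row-major device numbering described under strides.

```python
from itertools import product


def unravel_all(dim_names, shape):
    """Return {device_id: {dim_name: dim_index}} for every device in the mesh.

    itertools.product varies the last index fastest, which is row-major
    order, so enumerating it yields device IDs 0, 1, 2, ... in sequence.
    """
    return {
        device_id: dict(zip(dim_names, coords))
        for device_id, coords in enumerate(product(*(range(n) for n in shape)))
    }


for device_id, location in unravel_all(["x", "y"], [3, 2]).items():
    print(device_id, location)
# 0 {'x': 0, 'y': 0}
# 1 {'x': 0, 'y': 1}
# 2 {'x': 1, 'y': 0}
# 3 {'x': 1, 'y': 1}
# 4 {'x': 2, 'y': 0}
# 5 {'x': 2, 'y': 1}
```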

use_xla_spmd

use_xla_spmd(self: tensorflow.python._pywrap_dtensor_device.Mesh) -> bool

Returns True if Mesh will use XLA for SPMD instead of DTensor SPMD.

__contains__

__contains__(self: tensorflow.python._pywrap_dtensor_device.Mesh, dim_name: str) -> bool

__eq__

Return self==value.

__getitem__
