tf.distribute.cluster_resolver.TPUClusterResolver

Cluster Resolver for Google Cloud TPUs.

Inherits From: ClusterResolver

This is an implementation of cluster resolvers for the Google Cloud TPU service.

TPUClusterResolver supports the following distinct environments: Google Compute Engine, Google Kubernetes Engine, and Google internal.

It can be passed into tf.distribute.TPUStrategy to support TF2 training on Cloud TPUs.
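
A minimal sketch of that hand-off, assuming TensorFlow 2.x is installed and a Cloud TPU is reachable. The helper name build_tpu_strategy and its tpu_name parameter are illustrative and not part of the API; TensorFlow is imported inside the function so that defining the helper has no side effects.

```python
def build_tpu_strategy(tpu_name=""):
    """Hypothetical helper: resolve a TPU and wrap it in a TPUStrategy.

    An empty `tpu_name` asks the resolver to discover the TPU automatically.
    """
    import tensorflow as tf  # lazy import; assumes TensorFlow 2.x

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=tpu_name)
    # Connect to the TPU workers and initialize the TPU system before
    # building the strategy.
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    return tf.distribute.TPUStrategy(resolver)
```

Calling build_tpu_strategy() on a machine without access to a TPU is expected to fail, so the call itself is left to the reader.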

Args
tpu A string corresponding to the TPU to use. It can be the TPU name or TPU worker gRPC address. If not set, it will try to automatically resolve the TPU address on Cloud TPUs.
zone Zone where the TPUs are located. If omitted or empty, we will assume that the zone of the TPU is the same as the zone of the GCE VM, which we will try to discover from the GCE metadata service.
project Name of the GCP project containing Cloud TPUs. If omitted or empty, we will try to discover the project name of the GCE VM from the GCE metadata service.
job_name Name of the TensorFlow job the TPUs belong to.
coordinator_name The name to use for the coordinator. Set to None if the coordinator should not be included in the computed ClusterSpec.
coordinator_address The address of the coordinator (typically an ip:port pair). If set to None, a TF server will be started. If coordinator_name is None, a TF server will not be started even if coordinator_address is None.
credentials GCE Credentials. If None, then we use the default credentials from oauth2client.
service The GCE API object returned by the googleapiclient.discovery function. If you specify a custom service object, then the credentials parameter will be ignored.
discovery_url A URL template that points to the location of the discovery service. It should have two parameters {api} and {apiVersion} that when filled in produce an absolute URL to the discovery document for that service. The environment variable 'TPU_API_DISCOVERY_URL' will override this.
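
The discovery_url description can be illustrated with a short sketch of how a two-parameter template is filled in and how the environment variable takes precedence. The helper resolve_discovery_url, its environ parameter, and the placeholder host are assumptions made for illustration; this is not the resolver's actual implementation.

```python
import os


def resolve_discovery_url(discovery_url, api, api_version, environ=None):
    """Hypothetical helper: choose the template, then fill in its two fields."""
    environ = os.environ if environ is None else environ
    # TPU_API_DISCOVERY_URL, when set, overrides the discovery_url argument.
    template = environ.get("TPU_API_DISCOVERY_URL", discovery_url)
    return template.format(api=api, apiVersion=api_version)


# Placeholder template; the host below is not a real endpoint.
template = "https://example.invalid/discovery/{api}/{apiVersion}/rest"
url = resolve_discovery_url(template, "tpu", "v1", environ={})
print(url)
```

Passing environ explicitly keeps the sketch independent of the caller's real environment variables.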

Raises
ImportError If the googleapiclient is not installed.
ValueError If no TPUs are specified.
RuntimeError If an empty TPU name is specified and this is running in a Google Cloud environment.

Attributes
environment Returns the current environment which TensorFlow is running in.
task_id Returns the task id this ClusterResolver indicates.

In a TensorFlow distributed environment, each job may have an applicable task id, which is the index of the instance within its task type. This is useful when the user needs to run specific code according to the task index. For example,

import tensorflow as tf

cluster_spec = tf.train.ClusterSpec({
    "ps": ["localhost:2222", "localhost:2223"],
    "worker": ["localhost:2224", "localhost:2225", "localhost:2226"]
})

# SimpleClusterResolver is used here for illustration; other cluster
# resolvers may be used for other sources of task type/id.
cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
    cluster_spec, task_type="worker", task_id=0)

...

if cluster_resolver.task_type == 'worker' and cluster_resolver.task_id == 0:
    # Perform something that's only applicable on 'worker' type, id 0. This
    # block will run on this particular instance since we've specified this
    # task to be a 'worker', id 0 in the cluster resolver above.
    pass
else:
    # Perform something that's only applicable on other ids. This block will
    # not run on this particular instance.
    pass

Returns None if such information is not available or is not applicable in the current distributed environment, such as training with tf.distribute.cluster_resolver.TPUClusterResolver.

For more information, please see tf.distribute.cluster_resolver.ClusterResolver's class docstring.

task_type Returns the task type this ClusterResolver indicates.

In a TensorFlow distributed environment, each job may have an applicable task type. Valid task types in TensorFlow include 'chief': a worker that is designated with more responsibility, 'worker': a regular worker for training/evaluation, 'ps': a parameter server, and 'evaluator': an evaluator that evaluates the checkpoints for metrics.

See Multi-worker configuration for more information about the 'chief' and 'worker' task types, which are the most commonly used.

Having access to such information is useful when the user needs to run specific code according to task type. For example,

import tensorflow as tf

cluster_spec = tf.train.ClusterSpec({
    "ps": ["localhost:2222", "localhost:2223"],
    "worker": ["localhost:2224", "localhost:2225", "localhost:2226"]
})

# SimpleClusterResolver is used here for illustration; other cluster
# resolvers may be used for other sources of task type/id.
cluster_resolver = tf.distribute.cluster_resolver.SimpleClusterResolver(
    cluster_spec, task_type="worker", task_id=1)

...

if cluster_resolver.task_type == 'worker':
    # Perform something that's only applicable on workers. This block
    # will run on this particular instance since we've specified this task to
    # be a worker in the cluster resolver above.
    pass
elif cluster_resolver.task_type == 'ps':
    # Perform something that's only applicable on parameter servers. This
    # block will not run on this particular instance.
    pass

Returns None if such information is not available or is not applicable in the current distributed environment, such as training with tf.distribute.experimental.TPUStrategy.

For more information, please see tf.distribute.cluster_resolver.ClusterResolver's class docstring.

Methods

cluster_spec

Returns a ClusterSpec object based on the latest TPU information.

We retrieve the information from the GCE APIs every time this method is called.

Returns
A ClusterSpec containing host information returned from Cloud TPUs, or None.

Raises
RuntimeError If the provided TPU is not healthy.
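
To make the shape of the returned ClusterSpec concrete, here is a minimal sketch of how a ClusterSpec-style dictionary could be assembled from TPU host addresses and the job_name, coordinator_name, and coordinator_address arguments. The function build_cluster_dict, the default job name "worker", and the sample addresses are assumptions for illustration, not the resolver's real code.

```python
def build_cluster_dict(tpu_hosts, job_name="worker",
                       coordinator_name=None, coordinator_address=None):
    """Hypothetical sketch: map job names to lists of host:port strings."""
    cluster = {job_name: list(tpu_hosts)}
    # The coordinator is left out entirely when no name is given for it.
    if coordinator_name is not None:
        if coordinator_address is None:
            # The real resolver starts a TF server in this case; the sketch
            # simply requires an explicit address.
            raise ValueError("this sketch needs an explicit coordinator_address")
        cluster[coordinator_name] = [coordinator_address]
    return cluster


spec = build_cluster_dict(["10.1.2.3:8470", "10.1.2.4:8470"])
spec_with_coordinator = build_cluster_dict(
    ["10.1.2.3:8470"],
    coordinator_name="coordinator",
    coordinator_address="10.1.2.9:34567")
```

A dictionary of this shape is what tf.train.ClusterSpec accepts, as in the task_id and task_type examples.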

connect

Initializes TPU and returns a TPUClusterResolver.

This API connects to a remote TPU cluster and initializes the TPU hardware. Example usage:

resolver = tf.distribute.cluster_resolver.TPUClusterResolver.connect(
    tpu='')

It can be viewed as a convenient wrapper around the following code:

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

Args
tpu A string corresponding to the TPU to use. It can be the TPU name or TPU worker gRPC address. If not set, it will try to automatically resolve the TPU address on Cloud TPUs.
zone Zone where the TPUs are located. If omitted or empty, we will assume that the zone of the TPU is the same as the zone of the GCE VM, which we will try to discover from the GCE metadata service.
project Name of the GCP project containing Cloud TPUs. If omitted or empty, we will try to discover the project name of the GCE VM from the GCE metadata service.

Returns
A TPUClusterResolver instance.

Raises
NotFoundError If no TPU devices are found in eager mode.

get_job_name

get_master

get_tpu_system_metadata

Returns the metadata of the TPU system.

Users can call this method to get facts about the TPU system, such as the total number of cores, the number of TPU workers, and the devices. For example:

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tpu_system_metadata = resolver.get_tpu_system_metadata()
num_hosts = tpu_system_metadata.num_hosts

Returns
A tf.tpu.experimental.TPUSystemMetadata object.

master

Gets the master string to be used for the session.

In the normal case, this returns the gRPC path (for example, grpc://1.2.3.4:8470) of the first instance in the ClusterSpec returned by the cluster_spec function.

If a non-TPU name is used when constructing a TPUClusterResolver, that value is returned instead. For example, if the tpu argument's value when constructing this TPUClusterResolver was 'grpc://10.240.1.2:8470', then 'grpc://10.240.1.2:8470' is returned.
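
The two cases above can be sketched in plain Python. The function resolve_master, its parameters, and the check for a "grpc://" prefix are assumptions chosen for illustration; the real method may decide differently.

```python
def resolve_master(tpu, cluster_dict, job_name="worker", rpc_layer="grpc"):
    """Hypothetical sketch of choosing the master string."""
    # A value that is already a gRPC address is returned unchanged.
    if tpu.startswith("grpc://"):
        return tpu
    hosts = cluster_dict.get(job_name, [])
    if not hosts:
        raise ValueError("none of the specified TPUs exist")
    # Otherwise use the first instance listed for the job.
    return f"{rpc_layer}://{hosts[0]}"


direct = resolve_master("grpc://10.240.1.2:8470", {})
looked_up = resolve_master(
    "my-tpu", {"worker": ["1.2.3.4:8470", "1.2.3.5:8470"]})
```

Here "my-tpu" stands in for a TPU name, and the dictionary stands in for the ClusterSpec that cluster_spec would return.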

Args
task_type (Optional, string) The type of the TensorFlow task of the master.
task_id (Optional, integer) The index of the TensorFlow task of the master.
rpc_layer (Optional, string) The RPC protocol TensorFlow should use to communicate with TPUs.

Returns
A string containing the connection string to use when creating a session.

Raises
ValueError If none of the TPUs specified exists.

num_accelerators

Returns the number of TPU cores per worker.

Connects to the master, lists all the devices present in the master, and counts them. Also verifies that the device counts per host in the cluster are the same before returning the number of TPU cores per host.
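
The counting and consistency check described above can be sketched as follows. The function tpu_cores_per_host is hypothetical, and the sketch assumes device names follow the usual TensorFlow form /job:<job>/replica:<n>/task:<n>/device:TPU:<n>; it works on a plain list of names rather than a live connection.

```python
import re
from collections import Counter


def tpu_cores_per_host(device_names):
    """Hypothetical sketch: count TPU cores per task and check they agree."""
    counts = Counter()
    for name in device_names:
        # Match TPU cores only; other device types (CPU, TPU_SYSTEM) are skipped.
        match = re.search(r"/task:(\d+)/device:TPU:\d+$", name)
        if match:
            counts[int(match.group(1))] += 1
    if not counts:
        return 0
    distinct = set(counts.values())
    if len(distinct) != 1:
        raise RuntimeError(f"TPU cores per host differ: {dict(counts)}")
    return distinct.pop()


devices = [
    "/job:worker/replica:0/task:0/device:CPU:0",
    "/job:worker/replica:0/task:0/device:TPU:0",
    "/job:worker/replica:0/task:0/device:TPU:1",
    "/job:worker/replica:0/task:1/device:TPU:0",
    "/job:worker/replica:0/task:1/device:TPU:1",
    "/job:worker/replica:0/task:1/device:TPU_SYSTEM:0",
]
cores = tpu_cores_per_host(devices)
```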

Args
task_type Unused.
task_id Unused.
config_proto Used to create a connection to a TPU master in order to retrieve the system metadata.

Raises
RuntimeError If we cannot talk to a TPU worker after retrying or if the number of TPU devices per host is different.

__enter__

__exit__
