tf.nn.embedding_lookup


Looks up embeddings for the given ids from a list of tensors.

This function is used to perform parallel lookups on the list of tensors in params. It is a generalization of tf.gather, where params is interpreted as a partitioning of a large embedding tensor.

If len(params) > 1, each element id of ids is partitioned between the elements of params according to the "div" partition strategy, which means we assign ids to partitions in a contiguous manner. For instance, 13 ids are split across 5 partitions as: [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]].

If the number of ids (max_id + 1) is not evenly divisible by the number of partitions, each of the first (max_id + 1) % len(params) partitions is assigned one more id than the rest.
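The "div" assignment described above can be sketched in plain Python. This is only an illustration of the rule, not TensorFlow's implementation; the helper name div_partition is hypothetical.

```python
def div_partition(num_ids, num_partitions):
    """Split ids 0..num_ids-1 into contiguous blocks ("div" strategy)."""
    base, extra = divmod(num_ids, num_partitions)
    partitions = []
    start = 0
    for p in range(num_partitions):
        # The first `extra` partitions each hold one additional id.
        size = base + 1 if p < extra else base
        partitions.append(list(range(start, start + size)))
        start += size
    return partitions


print(div_partition(13, 5))
# [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
```

Here 13 % 5 == 3, so the first three partitions receive three ids and the remaining two receive two.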

The results of the lookup are concatenated into a dense tensor. The returned tensor has shape shape(ids) + shape(params)[1:].
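As a quick check of the shape rule, the following sketch computes the result shape from plain tuples. The helper name lookup_output_shape is hypothetical, and params_shape is assumed to be the shape of the full (unsharded) embedding tensor.

```python
def lookup_output_shape(ids_shape, params_shape):
    """Return shape(ids) + shape(params)[1:] as a tuple."""
    # The leading dimension of params (the number of ids) is dropped;
    # every id contributes one slice of the remaining dimensions.
    return tuple(ids_shape) + tuple(params_shape[1:])


print(lookup_output_shape((3,), (5, 2)))        # (3, 2)
print(lookup_output_shape((4, 7), (1000, 64)))  # (4, 7, 64)
```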

Args:
  params: A single tensor representing the complete embedding tensor, or a list of tensors, all of the same shape except for the first dimension, representing sharded embedding tensors that follow the "div" partition strategy.
  ids: A Tensor of type int32 or int64 containing the ids to be looked up in params.
  max_norm: If not None, each embedding is clipped if its L2 norm is larger than this value.
  name: A name for the operation (optional).
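The effect of max_norm on a single embedding vector can be sketched as follows. This assumes that clipping rescales the vector so its L2 norm equals max_norm, which is the usual meaning of norm clipping; the helper name clip_to_max_norm is hypothetical and the code is not taken from TensorFlow.

```python
import math


def clip_to_max_norm(embedding, max_norm):
    """Rescale `embedding` if its L2 norm exceeds `max_norm`."""
    norm = math.sqrt(sum(x * x for x in embedding))
    if norm <= max_norm:
        # Already within the limit: return an unchanged copy.
        return list(embedding)
    scale = max_norm / norm
    return [x * scale for x in embedding]


print(clip_to_max_norm([3.0, 4.0], 10.0))  # norm 5 <= 10, left unchanged
print(clip_to_max_norm([3.0, 4.0], 1.0))   # rescaled to roughly [0.6, 0.8]
```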

Returns:
  A Tensor with the same type as the tensors in params.

For instance, if params is a 5x2 matrix:

[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]

or a list of matrices:

params[0]: [[1, 2], [3, 4]]
params[1]: [[5, 6], [7, 8]]
params[2]: [[9, 10]]

and ids is:

[0, 3, 4]

The output will be a 3x2 matrix:

[[1, 2], [7, 8], [9, 10]]
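The example above can be reproduced with a small pure-Python emulation that maps each id to a shard and an offset under the "div" strategy. It is a sketch under the assumption that the shards are laid out exactly as that strategy prescribes; the function name sharded_lookup is hypothetical, and the code does not call TensorFlow.

```python
def sharded_lookup(params, ids):
    """Emulate a "div"-partitioned lookup over a list of row shards."""
    if not params:
        raise ValueError("params must not be empty")
    total = sum(len(shard) for shard in params)
    base, extra = divmod(total, len(params))
    # Ids below this threshold live in the larger (base + 1)-row shards.
    threshold = extra * (base + 1)
    rows = []
    for i in ids:
        if not 0 <= i < total:
            raise IndexError(f"id {i} is out of range")
        if i < threshold:
            shard, offset = divmod(i, base + 1)
        else:
            shard, offset = divmod(i - threshold, base)
            shard += extra
        rows.append(params[shard][offset])
    return rows


full = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
shards = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10]]]

print(sharded_lookup([full], [0, 3, 4]))  # [[1, 2], [7, 8], [9, 10]]
print(sharded_lookup(shards, [0, 3, 4]))  # [[1, 2], [7, 8], [9, 10]]
```

Both calls return the same rows, which illustrates that the sharded form is just a partitioning of the single 5x2 matrix.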

Raises:
  ValueError: If params is empty.