Returns the max k values and their indices of the input operand in an approximate manner.
tf.math.approx_max_k(
    operand,
    k,
    reduction_dimension=-1,
    recall_target=0.95,
    reduction_input_size_override=-1,
    aggregate_to_topk=True,
    name=None
)
See https://arxiv.org/abs/2206.14286 for the algorithm details. This op is currently optimized only for TPU.
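The pure-Python sketch below gives a rough idea of the binning approach. It splits the reduction dimension into bins and keeps only the maximum of each bin. It then takes an exact top-k over those bin winners. This is a simplified, hypothetical model, not the TPU implementation, and it makes two assumptions for brevity:

- The bin count num_bins is passed in directly rather than computed from recall_target and the input size.
- Bins are assigned by stride (element i goes to bin i % num_bins).

```python
def approx_max_k_sketch(row, k, num_bins):
    """Simplified two-stage approximate top-k over a 1-D list.

    Hypothetical helper for illustration only; not the TensorFlow op.
    Returns (values, indices), ordered largest value first.
    """
    if not 1 <= k <= num_bins <= len(row):
        raise ValueError("need 1 <= k <= num_bins <= len(row)")
    # Stage 1 (assumed strided binning): index i belongs to bin i % num_bins.
    # Keep only the index of the largest value in each bin.
    winners = []
    for b in range(num_bins):
        best = max(range(b, len(row), num_bins), key=lambda i: row[i])
        winners.append(best)
    # Stage 2: exact top-k, but only over the num_bins bin winners.
    winners.sort(key=lambda i: row[i], reverse=True)
    top = winners[:k]
    return [row[i] for i in top], top


row = [0.1, 0.9, 0.8, 0.3, 0.7, 0.2, 0.95, 0.4]
# One element per bin: the sketch degenerates to an exact top-k.
print(approx_max_k_sketch(row, k=3, num_bins=8))  # → ([0.95, 0.9, 0.8], [6, 1, 2])
# Four bins: 0.8 (index 2) shares a bin with 0.95 (index 6), so it is lost.
print(approx_max_k_sketch(row, k=3, num_bins=4))  # → ([0.95, 0.9, 0.7], [6, 1, 4])
```

The second call shows where the approximation comes from. A true top-k element is dropped only when a larger top-k element lands in the same bin. More bins make such collisions rarer, at the cost of a larger second stage.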
Returns
A tuple of two arrays: the max k values and the corresponding indices along the reduction_dimension of the input operand. Both arrays have the same dimensions as the input operand except for the reduction_dimension. When aggregate_to_topk is true, the size of that dimension is k. Otherwise, it is greater than or equal to k, and the exact size is implementation-defined.
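When aggregate_to_topk is false, each row holds an implementation-defined number of candidates (at least k). A caller that needs exactly k results must reduce those candidates itself.

The sketch below shows that final step on plain Python lists for a single row. It takes the exact top-k of the candidate values and carries the matching indices along. It assumes the returned indices already refer to positions in the original operand. The helper name finalize_topk is hypothetical. With tensors, the same step could be done with an exact top-k such as tf.math.top_k followed by a gather of the indices.

```python
def finalize_topk(cand_values, cand_indices, k):
    """Reduce one row of >= k approximate candidates to exactly k.

    cand_values[j] is a candidate value and cand_indices[j] is (assumed to be)
    its position in the original operand. Hypothetical helper for illustration.
    """
    if len(cand_values) != len(cand_indices):
        raise ValueError("values and indices must have the same length")
    if len(cand_values) < k:
        raise ValueError("expected at least k candidates")
    # Exact top-k over the candidates, ordered largest value first.
    order = sorted(range(len(cand_values)),
                   key=lambda j: cand_values[j], reverse=True)[:k]
    return [cand_values[j] for j in order], [cand_indices[j] for j in order]


# Five candidates for one row, as might be returned with aggregate_to_topk=False.
vals = [0.7, 0.9, 0.95, 0.4, 0.6]
idxs = [4, 1, 6, 7, 10]
print(finalize_topk(vals, idxs, k=3))  # → ([0.95, 0.9, 0.7], [6, 1, 4])
```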
We encourage users to wrap approx_max_k in a JIT-compiled function (tf.function(jit_compile=True)). The following example performs maximal inner product search (MIPS):
import tensorflow as tf

@tf.function(jit_compile=True)
def mips(qy, db, k=10, recall_target=0.95):
    # Inner products of every query with every database vector:
    # dists[i, j] = dot(qy[i], db[j]), shape [qy_size, db_size]
    dists = tf.einsum('ik,jk->ij', qy, db)
    # returns (f32[qy_size, k], i32[qy_size, k])
    return tf.math.approx_max_k(dists, k=k, recall_target=recall_target)

qy = tf.random.uniform((256, 128))
db = tf.random.uniform((2048, 128))
dot_products, neighbors = mips(qy, db, k=20)
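The recall_target parameter concerns recall, meaning the fraction of the true top-k results that the approximate search returns. To check the achieved recall on your own data, compare the returned neighbors against an exact top-k over the same dists, for example one computed with tf.math.top_k.

The pure-Python sketch below covers only the recall computation. The helper name recall_at_k and the toy index lists are hypothetical.

```python
def recall_at_k(approx_indices, exact_indices):
    """Mean per-row recall: len(approx & exact) / k, averaged over rows.

    Both arguments are lists of rows; each row holds k neighbor indices.
    Hypothetical helper for illustration.
    """
    if len(approx_indices) != len(exact_indices):
        raise ValueError("row counts differ")
    total = 0.0
    for approx_row, exact_row in zip(approx_indices, exact_indices):
        k = len(exact_row)
        # Count how many of the true top-k indices the approximate row found.
        total += len(set(approx_row) & set(exact_row)) / k
    return total / len(exact_indices)


exact = [[6, 1, 2], [0, 3, 5]]
approx = [[6, 1, 4], [3, 0, 5]]  # first row misses index 2; order is ignored
print(recall_at_k(approx, exact))  # ≈ 0.833 (mean of 2/3 and 1)
```

For eager tensors such as neighbors, pass neighbors.numpy().tolist() as approx_indices. Pass the exact top-k indices, converted the same way, as exact_indices.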