tf.math.approx_max_k
Returns the max k values of the input operand, and their indices, computed approximately.
tf.math.approx_max_k(
operand,
k,
reduction_dimension=-1,
recall_target=0.95,
reduction_input_size_override=-1,
aggregate_to_topk=True,
name=None
)
See https://arxiv.org/abs/2206.14286 for the algorithm details. This op is
currently optimized only for TPU.
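As a rough illustration of the bin-based idea described in the paper, the following pure-Python sketch splits a single row into `num_bins` bins, keeps only each bin's maximum as a candidate, and, when `aggregate_to_topk` is true, reduces those candidates to the exact top-k. The function name `approx_max_k_sketch`, the strided bin layout, and the caller-chosen `num_bins` are assumptions made for illustration; this is not the TensorFlow or XLA implementation.

```python
def approx_max_k_sketch(row, k, num_bins, aggregate_to_topk=True):
    """Hypothetical bin-max approximation of top-k over a 1-D list of floats.

    Assumption: bin b holds indices b, b + num_bins, b + 2 * num_bins, ...
    so every bin is non-empty whenever num_bins <= len(row).
    """
    n = len(row)
    if not 0 < k <= num_bins <= n:
        raise ValueError("require 0 < k <= num_bins <= len(row)")

    cand_vals, cand_idx = [], []
    for b in range(num_bins):
        # Keep only the largest element of this bin as a candidate.
        best = max(range(b, n, num_bins), key=lambda i: row[i])
        cand_vals.append(row[best])
        cand_idx.append(best)

    if not aggregate_to_topk:
        # Unaggregated: num_bins (>= k) candidates, in bin order.
        return cand_vals, cand_idx

    # Aggregate: exact top-k among the bin winners, largest first.
    order = sorted(range(num_bins), key=lambda j: cand_vals[j], reverse=True)[:k]
    return [cand_vals[j] for j in order], [cand_idx[j] for j in order]
```

With row `[0.1, 0.9, 0.8, 0.3, 0.7, 0.2]`, `k=3` and `num_bins=3`, the sketch returns indices `[1, 2, 3]`: the third-largest value 0.7 (index 4) shares a bin with 0.9 and is lost. Using more bins makes such misses less likely; with one element per bin the result is exact.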
Args
operand: Array to search for the max-k values. Must be a floating-point type.
k: Number of largest values to return.
reduction_dimension: Integer dimension along which to search. Default: -1.
recall_target: Recall target for the approximation. Default: 0.95.
reduction_input_size_override: When set to a positive value, overrides the
size of operand along reduction_dimension that is used to evaluate the recall.
This is useful when operand is only a subset of the overall computation in
SPMD or distributed pipelines, where the true input size cannot be inferred
from the operand shape.
aggregate_to_topk: When true, aggregates the approximate results to the top-k.
When false, returns the approximate results without aggregation; their number
is implementation-defined and is greater than or equal to k. Default: True.
name: Optional name for the operation.
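To build intuition for what recall_target controls, here is a back-of-the-envelope model that is our own assumption, not necessarily what TensorFlow computes. If the true top-k elements fall into L bins independently and uniformly at random, the r-th largest survives only when none of the r - 1 larger ones shares its bin, so the expected recall is (1/k) * sum over r = 1..k of (1 - 1/L)^(r - 1) = (L/k) * (1 - (1 - 1/L)^k). The hypothetical helper `bins_for_recall` searches for the smallest L meeting a target, capped by `input_size`, which plays a role loosely analogous to the size that reduction_input_size_override supplies.

```python
def expected_recall(k, num_bins):
    """Expected recall under the assumed random-bin model.

    Each of the true top-k lands in one of num_bins bins uniformly at random;
    an element is recovered only if it is the largest top-k element in its bin.
    """
    if k < 1 or num_bins < 1:
        raise ValueError("k and num_bins must be positive")
    p_alone = 1.0 - 1.0 / num_bins  # chance one larger element avoids this bin
    return num_bins / k * (1.0 - p_alone ** k)


def bins_for_recall(k, recall_target, input_size):
    """Smallest bin count in [k, input_size] whose modeled recall meets the target.

    Falls back to input_size (one element per bin, i.e. exact search) when
    no smaller bin count reaches recall_target under this model.
    """
    for num_bins in range(k, input_size + 1):
        if expected_recall(k, num_bins) >= recall_target:
            return num_bins
    return input_size
```

Under this model the recall grows with the number of bins, so a higher recall_target means more candidates survive the first pass and more work is spent aggregating them.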
Returns
A tuple of two arrays: the max-k values and their corresponding indices along
reduction_dimension of the input operand. Both arrays have the same dimensions
as operand, except for reduction_dimension: when aggregate_to_topk is true,
its size is k; otherwise, it is an implementation-defined size greater than or
equal to k.
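The shape rule for the returned arrays can be written as a small helper. This is a hypothetical function, not part of TensorFlow, and it only covers aggregate_to_topk=True, since the unaggregated size is implementation-defined. It normalizes a negative reduction_dimension the usual Python way.

```python
def approx_max_k_output_shape(operand_shape, k, reduction_dimension=-1):
    """Shape of each returned array when aggregate_to_topk=True (sketch)."""
    rank = len(operand_shape)
    if not -rank <= reduction_dimension < rank:
        raise ValueError("reduction_dimension out of range")
    dim = reduction_dimension % rank  # e.g. -1 -> rank - 1
    shape = list(operand_shape)
    shape[dim] = k  # only the reduction dimension changes
    return tuple(shape)
```

For example, inner products of shape (256, 2048) searched with k=20 along the last dimension give values and indices of shape (256, 20).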
We encourage wrapping approx_max_k in a JIT-compiled function
(tf.function(jit_compile=True)). The following example performs maximum
inner product search (MIPS):
import tensorflow as tf

@tf.function(jit_compile=True)
def mips(qy, db, k=10, recall_target=0.95):
  # Inner products between every query row and every database row.
  dists = tf.einsum('ik,jk->ij', qy, db)
  # returns (f32[qy_size, k], i32[qy_size, k])
  return tf.math.approx_max_k(dists, k=k, recall_target=recall_target)

qy = tf.random.uniform((256, 128))
db = tf.random.uniform((2048, 128))
dot_products, neighbors = mips(qy, db, k=20)
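To check how close the approximate neighbors come to the exact answer, one option is to compare them with exact top-k indices, which in TensorFlow could come from tf.math.top_k on the same inner products. The pure-Python helpers below (hypothetical names `exact_top_k_indices` and `mean_recall`) compute the average per-row recall from plain lists, for example lists obtained via `.numpy().tolist()`.

```python
def exact_top_k_indices(row, k):
    """Indices of the k largest values in row, largest first."""
    return sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]


def mean_recall(approx_indices, exact_indices):
    """Average fraction of each row's exact top-k found by the approximation."""
    if len(approx_indices) != len(exact_indices):
        raise ValueError("row counts differ")
    total = 0.0
    for approx_row, exact_row in zip(approx_indices, exact_indices):
        # Order does not matter for recall, so compare as sets.
        total += len(set(approx_row) & set(exact_row)) / len(exact_row)
    return total / len(exact_indices)
```

A measured mean recall well below the requested recall_target on representative data would be a signal to raise the target or to fall back to exact search.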
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2023-03-23 UTC.