在 TensorFlow.org 上查看 | 在 Google Colab 中运行 | 在 GitHub 上查看源代码 | 下载笔记本 |
使用包含大量零值的张量时,务必以节省空间和时间的方式存储它们。稀疏张量可以高效存储和处理包含大量零值的张量。稀疏张量广泛用于 TF-IDF 等编码方案,作为 NLP 应用中数据预处理的一部分,以及在计算机视觉应用中预处理具有大量暗像素的图像。
TensorFlow 中的稀疏张量
TensorFlow 通过 tf.sparse.SparseTensor
对象表示稀疏张量。目前,TensorFlow 中的稀疏张量使用坐标列表 (COO) 格式进行编码。这种编码格式针对嵌入向量等超稀疏矩阵进行了优化。
稀疏张量的 COO 编码包括:
values
:形状为[N]
的一维张量,包含所有非零值。indices
:形状为[N, rank]
的二维张量,包含非零值的索引。dense_shape
:形状为[rank]
的一维张量,指定张量的形状。
tf.sparse.SparseTensor
上下文中的非零值是未显式编码的值。可以在 COO 稀疏矩阵的 values
中显式包含零值,但在稀疏张量中引用非零值时,通常不包含这些“显式零”。
注:tf.sparse.SparseTensor
不要求索引/值按任何特定顺序排列,但一些运算假定它们按行优先顺序排列。使用 tf.sparse.reorder
创建按规范行优先顺序排序的稀疏张量副本。
创建 tf.sparse.SparseTensor
通过直接指定它们的 values
、indices
和 dense_shape
来构造稀疏张量。
import tensorflow as tf
2022-12-14 22:51:34.880346: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 22:51:34.880458: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 22:51:34.880468: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.
st1 = tf.sparse.SparseTensor(indices=[[0, 3], [2, 4]],
values=[10, 20],
dense_shape=[3, 10])
当您使用 print()
函数打印一个稀疏张量时,它会显示三个张量分量的内容:
print(st1)
SparseTensor(indices=tf.Tensor( [[0 3] [2 4]], shape=(2, 2), dtype=int64), values=tf.Tensor([10 20], shape=(2,), dtype=int32), dense_shape=tf.Tensor([ 3 10], shape=(2,), dtype=int64))
如果非零 values
与其对应的 indices
对齐,则更容易理解稀疏张量的内容。定义一个辅助函数来以美观的格式打印稀疏张量,以便每个非零值都显示在自己的行上。
def pprint_sparse_tensor(st):
s = "<SparseTensor shape=%s \n values={" % (st.dense_shape.numpy().tolist(),)
for (index, value) in zip(st.indices, st.values):
s += f"\n %s: %s" % (index.numpy().tolist(), value.numpy().tolist())
return s + "}>"
print(pprint_sparse_tensor(st1))
<SparseTensor shape=[3, 10] values={ [0, 3]: 10 [2, 4]: 20}>
您还可以使用 tf.sparse.from_dense
从密集张量构造稀疏张量,并使用 tf.sparse.to_dense
将它们转换回密集张量。
st2 = tf.sparse.from_dense([[1, 0, 0, 8], [0, 0, 0, 0], [0, 0, 3, 0]])
print(pprint_sparse_tensor(st2))
<SparseTensor shape=[3, 4] values={ [0, 0]: 1 [0, 3]: 8 [2, 2]: 3}>
st3 = tf.sparse.to_dense(st2)
print(st3)
tf.Tensor( [[1 0 0 8] [0 0 0 0] [0 0 3 0]], shape=(3, 4), dtype=int32)
操纵稀疏张量
使用 tf.sparse
软件包中的实用工具来操纵稀疏张量。像 tf.math.add
这样可用于密集张量的算术操纵的运算不适用于稀疏张量。
使用 tf.sparse.add
添加相同形状的稀疏张量。
st_a = tf.sparse.SparseTensor(indices=[[0, 2], [3, 4]],
values=[31, 2],
dense_shape=[4, 10])
st_b = tf.sparse.SparseTensor(indices=[[0, 2], [7, 0]],
values=[56, 38],
dense_shape=[4, 10])
st_sum = tf.sparse.add(st_a, st_b)
print(pprint_sparse_tensor(st_sum))
<SparseTensor shape=[4, 10] values={ [0, 2]: 87 [3, 4]: 2 [7, 0]: 38}>
使用 tf.sparse.sparse_dense_matmul
将稀疏张量与密集矩阵相乘。
st_c = tf.sparse.SparseTensor(indices=([0, 1], [1, 0], [1, 1]),
values=[13, 15, 17],
dense_shape=(2,2))
mb = tf.constant([[4], [6]])
product = tf.sparse.sparse_dense_matmul(st_c, mb)
print(product)
tf.Tensor( [[ 78] [162]], shape=(2, 1), dtype=int32)
使用 tf.sparse.concat
将稀疏张量放在一起,使用 tf.sparse.slice
将它们分开。
sparse_pattern_A = tf.sparse.SparseTensor(indices = [[2,4], [3,3], [3,4], [4,3], [4,4], [5,4]],
values = [1,1,1,1,1,1],
dense_shape = [8,5])
sparse_pattern_B = tf.sparse.SparseTensor(indices = [[0,2], [1,1], [1,3], [2,0], [2,4], [2,5], [3,5],
[4,5], [5,0], [5,4], [5,5], [6,1], [6,3], [7,2]],
values = [1,1,1,1,1,1,1,1,1,1,1,1,1,1],
dense_shape = [8,6])
sparse_pattern_C = tf.sparse.SparseTensor(indices = [[3,0], [4,0]],
values = [1,1],
dense_shape = [8,6])
sparse_patterns_list = [sparse_pattern_A, sparse_pattern_B, sparse_pattern_C]
sparse_pattern = tf.sparse.concat(axis=1, sp_inputs=sparse_patterns_list)
print(tf.sparse.to_dense(sparse_pattern))
tf.Tensor( [[0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0] [0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0] [0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0] [0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0] [0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0] [0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0] [0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]], shape=(8, 17), dtype=int32)
sparse_slice_A = tf.sparse.slice(sparse_pattern_A, start = [0,0], size = [8,5])
sparse_slice_B = tf.sparse.slice(sparse_pattern_B, start = [0,5], size = [8,6])
sparse_slice_C = tf.sparse.slice(sparse_pattern_C, start = [0,10], size = [8,6])
print(tf.sparse.to_dense(sparse_slice_A))
print(tf.sparse.to_dense(sparse_slice_B))
print(tf.sparse.to_dense(sparse_slice_C))
tf.Tensor( [[0 0 0 0 0] [0 0 0 0 0] [0 0 0 0 1] [0 0 0 1 1] [0 0 0 1 1] [0 0 0 0 1] [0 0 0 0 0] [0 0 0 0 0]], shape=(8, 5), dtype=int32) tf.Tensor( [[0] [0] [1] [1] [1] [1] [0] [0]], shape=(8, 1), dtype=int32) tf.Tensor([], shape=(8, 0), dtype=int32)
如果使用的是 TensorFlow 2.4 或更高版本,请使用 tf.sparse.map_values
对稀疏张量中的非零值执行逐元素运算。
st2_plus_5 = tf.sparse.map_values(tf.add, st2, 5)
print(tf.sparse.to_dense(st2_plus_5))
tf.Tensor( [[ 6 0 0 13] [ 0 0 0 0] [ 0 0 8 0]], shape=(3, 4), dtype=int32)
请注意,仅修改了非零值 – 零值保持为零。
同样,可以遵循下方 TensorFlow 早期版本的设计模式:
st2_plus_5 = tf.sparse.SparseTensor(
st2.indices,
st2.values + 5,
st2.dense_shape)
print(tf.sparse.to_dense(st2_plus_5))
tf.Tensor( [[ 6 0 0 13] [ 0 0 0 0] [ 0 0 8 0]], shape=(3, 4), dtype=int32)
将 tf.sparse.SparseTensor
与其他 TensorFlow API 一起使用
稀疏张量透明地与这些 TensorFlow API 一起使用:
tf.keras
tf.data
tf.Train.Example
protobuftf.function
tf.while_loop
tf.cond
tf.identity
tf.cast
tf.print
tf.saved_model
tf.io.serialize_sparse
tf.io.serialize_many_sparse
tf.io.deserialize_many_sparse
tf.math.abs
tf.math.negative
tf.math.sign
tf.math.square
tf.math.sqrt
tf.math.erf
tf.math.tanh
tf.math.bessel_i0e
tf.math.bessel_i1e
下面显示了上述 API 的一些示例。
tf.keras
tf.keras
API 的一个子集支持稀疏张量,无需执行开销较大的类型转换或转换运算。可以利用 Keras API 将稀疏张量作为输入传递给 Keras 模型。调用 tf.keras.Input
或 tf.keras.layers.InputLayer
时设置 sparse=True
。您可以在 Keras 层之间传递稀疏张量,也可以让 Keras 模型将它们作为输出返回。如果在模型中的 tf.keras.layers.Dense
层中使用稀疏张量,它们将输出密集张量。
下面的示例展示了如果仅使用支持稀疏输入的层,如何将稀疏张量作为输入传递给 Keras 模型。
x = tf.keras.Input(shape=(4,), sparse=True)
y = tf.keras.layers.Dense(4)(x)
model = tf.keras.Model(x, y)
sparse_data = tf.sparse.SparseTensor(
indices = [(0,0),(0,1),(0,2),
(4,3),(5,0),(5,1)],
values = [1,1,1,1,1,1],
dense_shape = (6,4)
)
model(sparse_data)
model.predict(sparse_data)
1/1 [==============================] - 0s 88ms/step array([[-0.08782148, 0.12210894, 0.17537934, 0.3671844 ], [ 0. , 0. , 0. , 0. ], [ 0. , 0. , 0. , 0. ], [ 0. , 0. , 0. , 0. ], [ 0.16802937, 0.04540855, 0.08590454, 0.69287664], [-0.8176182 , -0.35538226, -0.17415595, 0.4200449 ]], dtype=float32)
tf.data
tf.data
API 可用于通过简单的可重用代码段构建复杂的输入流水线。它的核心数据结构是 tf.data.Dataset
,表示一系列元素,每个元素包含一个或多个分量。
使用稀疏张量构建数据集
使用用于从 tf.Tensor
或 NumPy 数组构建数据集的相同方法从稀疏张量构建数据集,例如 tf.data.Dataset.from_tensor_slices
。此运算保留了数据的稀疏性。
dataset = tf.data.Dataset.from_tensor_slices(sparse_data)
for element in dataset:
print(pprint_sparse_tensor(element))
<SparseTensor shape=[4] values={ [0]: 1 [1]: 1 [2]: 1}> <SparseTensor shape=[4] values={}> <SparseTensor shape=[4] values={}> <SparseTensor shape=[4] values={}> <SparseTensor shape=[4] values={ [3]: 1}> <SparseTensor shape=[4] values={ [0]: 1 [1]: 1}>
批处理和取消批处理具有稀疏张量的数据集
可以分别使用 Dataset.batch
和 Dataset.unbatch
方法批处理(将连续元素组合成单个元素)和取消批处理具有稀疏张量的数据集。
batched_dataset = dataset.batch(2)
for element in batched_dataset:
print (pprint_sparse_tensor(element))
<SparseTensor shape=[2, 4] values={ [0, 0]: 1 [0, 1]: 1 [0, 2]: 1}> <SparseTensor shape=[2, 4] values={}> <SparseTensor shape=[2, 4] values={ [0, 3]: 1 [1, 0]: 1 [1, 1]: 1}>
unbatched_dataset = batched_dataset.unbatch()
for element in unbatched_dataset:
print (pprint_sparse_tensor(element))
<SparseTensor shape=[4] values={ [0]: 1 [1]: 1 [2]: 1}> <SparseTensor shape=[4] values={}> <SparseTensor shape=[4] values={}> <SparseTensor shape=[4] values={}> <SparseTensor shape=[4] values={ [3]: 1}> <SparseTensor shape=[4] values={ [0]: 1 [1]: 1}>
还可以使用 tf.data.experimental.dense_to_sparse_batch
将不同形状的数据集元素批处理为稀疏张量。
转换具有稀疏张量的数据集
使用 Dataset.map
在数据集中转换和创建稀疏张量。
transform_dataset = dataset.map(lambda x: x*2)
for i in transform_dataset:
print(pprint_sparse_tensor(i))
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 <SparseTensor shape=[4] values={ [0]: 2 [1]: 2 [2]: 2}> <SparseTensor shape=[4] values={}> <SparseTensor shape=[4] values={}> <SparseTensor shape=[4] values={}> <SparseTensor shape=[4] values={ [3]: 2}> <SparseTensor shape=[4] values={ [0]: 2 [1]: 2}>
tf.train.Example
tf.train.Example
是 TensorFlow 数据的标准 protobuf 编码。将稀疏张量与 tf.train.Example
搭配使用时,您可以:
使用
tf.io.VarLenFeature
将可变长度数据读入tf.sparse.SparseTensor
。但是,您应当考虑改用tf.io.RaggedFeature
。使用
tf.io.SparseFeature
将任意稀疏数据读入tf.sparse.SparseTensor
,它使用三个独立的特征键来存储indices
、values
和dense_shape
。
tf.function
tf.function
装饰器为 Python 函数预先计算 TensorFlow 计算图,这样可以大幅提升 TensorFlow 代码的性能。稀疏张量能够透明地与 tf.function
和具体函数一起使用。
@tf.function
def f(x,y):
return tf.sparse.sparse_dense_matmul(x,y)
a = tf.sparse.SparseTensor(indices=[[0, 3], [2, 4]],
values=[15, 25],
dense_shape=[3, 10])
b = tf.sparse.to_dense(tf.sparse.transpose(a))
c = f(a,b)
print(c)
tf.Tensor( [[225 0 0] [ 0 0 0] [ 0 0 625]], shape=(3, 3), dtype=int32)
区分缺失值和零值
tf.sparse.SparseTensor
上的大多数运算都以相同的方式处理缺失值和显式零值。这是特意这样设计的 – tf.sparse.SparseTensor
的行为应当像密集张量一样。
但是,在少数情况下,区分零值和缺失值会十分有用。特别是,这样可以对训练数据中的缺失/未知数据进行编码。例如,考虑一个用例,其中包含一个分数张量(可以具有从 -Inf 到 +Inf 的任何浮点值),但缺少一些分数。可以使用稀疏张量对此张量进行编码,其中显式零是已知的零分数,但隐式零值实际上表示缺失数据,而不是零。
注:这通常不是 tf.sparse.SparseTensor
的预期用途;并且您可能还想考虑其他技术来对此进行编码,例如使用单独的掩码张量来识别已知/未知值的位置。但是,在使用这种方式时要格外小心,因为大多数稀疏运算将以相同的方式处理显式和隐式零值。
请注意,像 tf.sparse.reduce_max
这样的一些运算不会将缺失值视为零。例如,运行下面的代码块时,预期输出为 0
。但是,由于此异常,输出为 -3
。
print(tf.sparse.reduce_max(tf.sparse.from_dense([-5, 0, -3])))
tf.Tensor(-3, shape=(), dtype=int32)
相反,当您将 tf.math.reduce_max
应用于密集张量时,输出如预期的那样为 0。
print(tf.math.reduce_max([-5, 0, -3]))
tf.Tensor(0, shape=(), dtype=int32)
补充阅读和资源
- 请参阅张量指南来了解张量。
- 阅读不规则张量指南以了解如何使用不规则张量,这是一种可以处理非均匀数据的张量。
- 在 TensorFlow Model Garden 中查看此目标检测模型,此模型在
tf.Example
数据解码器中使用稀疏张量。