Pytorch geometric中SparseTensor的三种压缩存储方式
最近在研究Pytorch geometric中的SparseTensor
对象,它是Pytorch_geometric内部用于存储稀疏矩阵的对象,它可以以三种不同的压缩存储方式来保存稀疏矩阵:COO、CSR、CSC。本文简单介绍一下这三种压缩存储方式。
1. Coordinate Format (COO)
这种存储方式最直接,它使用3个数组来存储一个稀疏矩阵。通过row
和 col
数组指定元素的行索引和列索引,values
中对应的值就是元素值。
-
row indices
: 存储每个元素的行索引 -
col indices
: 存储每个元素的列索引 -
values
: 存储每个元素的值
>>> sp
SparseTensor(row=tensor([0, 0, 1, 1, 2, 2, 2, 3, 3]),
col=tensor([0, 1, 1, 2, 0, 2, 3, 1, 3]),
val=tensor([1, 7, 2, 8, 5, 3, 9, 6, 4]),
size=(4, 4), nnz=9, density=56.25%)
>>> sp.to_dense()
tensor([[1, 7, 0, 0],
[0, 2, 8, 0],
[5, 0, 3, 9],
[0, 6, 0, 4]])
>>> sp.coo()
(row = tensor([0, 0, 1, 1, 2, 2, 2, 3, 3]),
col = tensor([0, 1, 1, 2, 0, 2, 3, 1, 3]),
val = tensor([1, 7, 2, 8, 5, 3, 9, 6, 4]))
2. Compressed Sparse Row Format (CSR)
这种存储方式稍微复杂一些,它同样是使用3个数组来保存一个稀疏矩阵:row ptr
、column indices
和values
。换个角度理解,我们可以认为CSR就是在COO的基础上,**将row数组进行压缩,另外两个数组保持不变。**在原来的COO中,相同行的元素会在row保存重复的行索引,所以我们在row中将重复的行索引删去,用row中的元素来指定当前行中所有非零元素在values中的范围。从而删去冗余的行索引。
row ptr
指定每一行的第一个非零元素在values
(或者是column indices
)中对应的索引。例如row_ptr[1]
的值表示为第二行的第一个非零元素在values
中的索引。稀疏矩阵第n行中所有的非零元素在values
中的索引范围为:[row_ptr[n], row_ptr[n+1]),意思是第n行中所有的非零元素按顺序保存在values[row_ptr[n]]
至values[row_ptr[n+1]]
中。column indices
指定values
中对应元素的行索引。。values
保存数值。
>>> sp
SparseTensor(row=tensor([0, 0, 1, 1, 2, 2, 2, 3, 3]),
col=tensor([0, 1, 1, 2, 0, 2, 3, 1, 3]),
val=tensor([1, 7, 2, 8, 5, 3, 9, 6, 4]),
size=(4, 4), nnz=9, density=56.25%)
>>> sp.to_dense()
tensor([[1, 7, 0, 0],
[0, 2, 8, 0],
[5, 0, 3, 9],
[0, 6, 0, 4]])
>>> sp.csr()
(row_ptr = tensor([0, 2, 4, 7, 9]),
col_ind = tensor([0, 1, 1, 2, 0, 2, 3, 1, 3]),
values = tensor([1, 7, 2, 8, 5, 3, 9, 6, 4]))
3. Compressed Sparse Column Format (CSC)
这种存储方式稍微复杂一些,同样是使用3个数组来保存一个稀疏矩阵:col ptr
、row indices
和values
。它的压缩方式与CSR相似,只不过CSR是按行压缩,而CSC是按列压缩:
col ptr
指定每一列的第一个非零元素在values
(或者是row indices
)中对应的索引。例如col_ptr[1]
的值表示为第二列的第一个非零元素在values
中的索引。稀疏矩阵第n列中所有的非零元素在values
中的索引范围为:[col_ptr[n], col_ptr[n+1]),意思是第n列中所有的非零元素按顺序保存在values[col_ptr[n]]
至values[col_ptr[n+1]]
中。row indices
指定values
中对应元素的行索引。values
保存数值。
>>> sp
SparseTensor(row=tensor([0, 0, 1, 1, 2, 2, 2, 3, 3]),
col=tensor([0, 1, 1, 2, 0, 2, 3, 1, 3]),
val=tensor([1, 7, 2, 8, 5, 3, 9, 6, 4]),
size=(4, 4), nnz=9, density=56.25%)
>>> sp.to_dense()
tensor([[1, 7, 0, 0],
[0, 2, 8, 0],
[5, 0, 3, 9],
[0, 6, 0, 4]])
>>> sp.csc()
(col_ptr = tensor([0, 2, 5, 7, 9]),
row_ind = tensor([0, 2, 0, 1, 3, 1, 2, 2, 3]),
values = tensor([1, 5, 7, 2, 6, 8, 3, 9, 4])