稀疏矩阵(Sparse Matrix)是线性代数和计算机科学中的一个重要概念,广泛应用于科学计算、工程模拟、图像处理、机器学习等多个领域。与稠密矩阵(Dense Matrix)相比,稀疏矩阵大部分元素为零,仅有少数非零元素。这一特性使得稀疏矩阵在存储和计算上具有显著的优势,尤其在处理大规模数据时更为高效。
一、稀疏矩阵的定义与性质
1. 定义
稀疏矩阵是指在一个矩阵中
,大多数元素为零,只有少数元素为非零值的矩阵
。形式上,给定一个 m × n m \times n m×n 的矩阵 A A A,如果 A A A 中非零元素的数量远小于 m × n m \times n m×n,则称 A A A 为稀疏矩阵。
2. 稠密矩阵与稀疏矩阵的对比
- 稠密矩阵(Dense Matrix):
矩阵中大部分元素为非零值
。例如,一个 1000 × 1000 1000 \times 1000 1000×1000 的稠密矩阵大约有 1 0 6 10^6 106 个非零元素。 - 稀疏矩阵(Sparse Matrix):矩阵中大部分元素为零。例如,一个 1000 × 1000 1000 \times 1000 1000×1000 的稀疏矩阵可能只有 1 0 3 10^3 103 个非零元素。
3. 稀疏矩阵的性质
- 非零元素稀少:稀疏矩阵中的非零元素数量远小于矩阵的总元素数量。
- 结构特性:稀疏矩阵往往具有特定的结构特性,如
对角线、带状、块状
等。 - 存储与计算效率:由于非零元素稀少,可以采用
专门的存储格式和算法
,提高存储和计算效率。
二、稀疏矩阵的存储格式
为了高效地存储和操作稀疏矩阵,研究人员设计了多种存储格式。这些格式主要通过仅存储非零元素及其位置信息,减少内存占用和提高访问速度。以下是几种常见的稀疏矩阵存储格式:
1. 压缩稀疏行(Compressed Sparse Row, CSR)
压缩稀疏行(Compressed Sparse Row, CSR)是一种存储稀疏矩阵的有效数据结构,常用于科学计算、机器学习、自然语言处理等领域,特别是处理大规模稀疏矩阵时。稀疏矩阵是指大多数元素为零的矩阵,使用常规的二维数组存储会浪费大量内存,而CSR格式可以显著降低内存的使用,并提高计算效率。
1. CSR格式的结构
CSR格式通过三种数组来存储一个稀疏矩阵的非零元素及其位置信息。假设矩阵有 m m m 行, n n n 列,总共有 z z z 个非零元素,CSR格式的存储结构由以下三个数组组成:
values
:存储矩阵的所有非零元素。column_indices
:存储对应非零元素的列索引,和values
数组中的元素一一对应。row_ptr
:存储每一行的起始位置索引,表示每一行开始的非零元素在values
数组中的位置。
2. CSR格式的存储结构详解
考虑一个 m × n m \times n m×n 的稀疏矩阵 A A A,其中 z z z 个元素非零,矩阵的内容如下所示:
0 | 1 | 2 | 3 | 4 | |
---|---|---|---|---|---|
0 | 1 | 0 | 0 | 0 | 2 |
1 | 0 | 0 | 3 | 0 | 4 |
2 | 5 | 0 | 0 | 0 | 6 |
2.1. values
(非零元素值数组)
values
数组按行优先的顺序存储矩阵中的非零元素,忽略零元素
。例如,矩阵中的非零元素为:1, 2, 3, 4, 5, 6,因此:
values = [1, 2, 3, 4, 5, 6]
2.2. column_indices(列索引数组)
column_indices
数组存储每个非零元素对应的列索引
。注意,column_indices
中的索引与 values
数组中的元素一一对应。例如:
-
第一个非零元素 1 在第0行的第0列,因此
column_indices[0] = 0
。 -
第二个非零元素 2 在第0行的第4列,因此
column_indices[1] = 4
。 -
第三个非零元素 3 在第1行的第1列,因此
column_indices[2] = 1
。 -
以此类推,最终得到:
column_indices = [0, 4, 1, 3, 2, 4]
2.3. row_ptr(行指针数组)
row_ptr
数组表示每一行中第一个非零元素
在 values
数组中的位置。它的长度是 m + 1 m + 1 m+1(矩阵行数加1),因为最后一个元素表示最后一行非零元素的结尾。row_ptr[i]
表示第 i i i 行的第一个非零元素在 values
数组中的索引。矩阵中每一行的非零元素分布如下:
- 第0行的非零元素为 1 和 2,它们在
values
数组中的位置分别是索引 0 和 1。 - 第1行的非零元素为 3 和 4,它们在
values
数组中的位置分别是索引 2 和 3。 - 第2行的非零元素为 5 和 6,它们在
values
数组中的位置分别是索引 4 和 5。
因此,row_ptr
数组如下:
row_ptr = [0, 2, 4, 6]
3. CSR格式存储矩阵示例
将上述矩阵转换为CSR格式后,矩阵 A A A 变为:
values = [1, 2, 3, 4, 5, 6]
column_indices = [0, 4, 1, 3, 2, 4]
row_ptr = [0, 2, 4, 6]
这个表示法有效地将稀疏矩阵的非零元素及其位置信息存储在三个一维数组中。由于稀疏矩阵中大部分元素是零,因此这种存储方式大大节省了内存。
4.查找矩阵中特定元素
假设我们要查找矩阵中的元素 A [ i , j ] A[i, j] A[i,j](即第 i i i 行,第 j j j 列的元素)。
-
确定目标行:首先,根据目标行 i i i 查找该行中非零元素的位置。可以通过
row_ptr[i]
获取该行的第一个非零元素在values
中的索引。row_ptr[i+1]
则表示该行非零元素的结束位置。例如:
- 对于第 0 行,
row_ptr[0] = 0
,row_ptr[1] = 2
,表示第 0 行的非零元素在values
数组中的位置范围是从索引 0 到 1(即values[0]
和values[1]
)。 - 对于第 1 行,
row_ptr[1] = 2
,row_ptr[2] = 4
,表示第 1 行的非零元素在values
数组中的位置范围是从索引 2 到 3(即values[2]
和values[3]
)。
- 对于第 0 行,
-
查找列索引:通过
column_indices
数组,查找该行非零元素的列索引。例如,column_indices
数组中的值对应于非零元素在该行中的列位置。 -
查找元素是否存在:如果列索引数组中包含目标列 j j j,那么就可以找到该元素;否则,目标元素为零。
示例查找
假设我们要查找 A [ 1 , 3 ] A[1, 3] A[1,3],即第 1 行,第 3 列的元素。
-
确定第 1 行非零元素的索引范围:
row_ptr[1] = 2
和row_ptr[2] = 4
,说明第 1 行的非零元素索引位于values[2]
到values[3]
,即values[2] = 3
和values[3] = 4
。 -
检查列索引:查看
column_indices[2] = 1
和column_indices[3] = 3
。目标列索引 3 在column_indices[3]
中找到。 -
返回对应的元素:在
values[3]
中找到对应的非零元素 4。
因此,get_element(1, 3)
将返回值 4。
import numpy as np
from scipy.sparse import csr_matrix
# Step 1: 创建一个稀疏矩阵
dense_matrix = np.array([[1, 0, 0, 0, 2],
[0, 3, 0, 4, 0],
[0, 0, 5, 0, 6]])
# 将稀疏矩阵转换为CSR格式
csr = csr_matrix(dense_matrix)
# 提取CSR格式的三个数组
values = csr.data
column_indices = csr.indices
row_ptr = csr.indptr
# 输出CSR表示
print("Values:", values) # 非零元素的值
print("Column indices:", column_indices) # 非零元素的列索引
print("Row pointer:", row_ptr) # 每行第一个非零元素的索引
# Step 2: 查找函数实现
# 查找矩阵中特定的元素 A[i, j]
def get_element(i, j, row_ptr, column_indices, values):
# 查找第i行的非零元素的索引范围
start_idx = row_ptr[i] # 第i行第一个非零元素的索引
end_idx = row_ptr[i + 1] # 第i行最后一个非零元素的索引 + 1
# 在该范围内查找是否有列索引等于j
for idx in range(start_idx, end_idx):
if column_indices[idx] == j:
return values[idx] # 返回对应的非零元素值
# 如果没有找到,返回0,表示该位置的元素为零
return 0
# 示例:查找 A[1, 3] 元素
print(f"A[1, 3] = {
get_element(1, 3, row_ptr, column_indices, values)}")
# 查找所有非零元素及其位置
def get_all_non_zero_elements(row_ptr, column_indices, values):
non_zero_elements = []
for i in range(len(row_ptr) - 1): # 遍历每一行
start_idx = row_ptr[i]
end_idx = row_ptr[i + 1]
for idx in range(start_idx, end_idx):
# 获取非零元素的行索引、列索引和值
non_zero_elements.append((i, column_indices[idx], values[idx]))
return non_zero_elements
# 输出所有非零元素
print("All non-zero elements:", get_all_non_zero_elements(row_ptr, column_indices, values))
# 查找第1行的所有非零元素
def get_row_non_zero_elements(i, row_ptr, column_indices, values):
start_idx = row_ptr[i]
end_idx = row_ptr[i + 1]
row_elements = []
for idx in range(start_idx, end_idx):
row_elements.append((column_indices[idx], values[idx])) # 存储列索引和非零值
return row_elements
print(f"Non-zero elements in row 1: {
get_row_non_zero_elements(1, row_ptr, column_indices, values)}")
# 查找第4列的所有非零元素
def get_column_non_zero_elements(j, row_ptr, column_indices, values):
column_elements = []
for i in range(len(row_ptr) - 1): # 遍历每一行
start_idx = row_ptr[i]
end_idx = row_ptr[i + 1]
for idx in range(start_idx, end_idx):
if column_indices[idx] == j:
column_elements.append((i, values[idx])) # 存储行索引和非零值
return column_