A brief scipy.sparse usage example

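CDIMC-Net [1] has a function, get_kNNgraph2 [2], that builds a kNN graph over the whole dataset. It stores the graph in a dense numpy.ndarray, so the space complexity is $O(n^2)$, which eats a lot of memory on large datasets even though a kNN graph is actually very sparse. Here it is rewritten with the scipy.sparse API.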
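To make the memory argument concrete, here is a rough sketch that compares the size of a dense `(n, n)` array with the buffers of a CSR matrix holding `k` entries per row. The helpers `dense_nbytes` and `csr_nbytes`, the sizes `n` and `k`, and the synthetic "next k indices" neighbour pattern are all assumptions made for illustration; they are not part of CDIMC-Net.

```python
import numpy as np
from scipy.sparse import csr_matrix


def dense_nbytes(n, dtype=np.int8):
    """Bytes needed by a dense (n, n) array of the given dtype."""
    return n * n * np.dtype(dtype).itemsize


def csr_nbytes(m):
    """Bytes held by the three buffers of a CSR matrix."""
    return m.data.nbytes + m.indices.nbytes + m.indptr.nbytes


n, k = 2000, 10  # illustrative sizes, not taken from the original post
# Synthetic neighbours: row i points at i+1, ..., i+k (mod n).
# This only mimics the sparsity of a kNN graph, it is not a real one.
rows = np.repeat(np.arange(n), k)
cols = (rows + np.tile(np.arange(1, k + 1), n)) % n
vals = np.ones(n * k, dtype=np.int8)
G = csr_matrix((vals, (rows, cols)), shape=(n, n))

print("dense bytes:", dense_nbytes(n))
print("CSR bytes:  ", csr_nbytes(G))
```

The dense size grows with `n * n`, while the CSR buffers grow roughly with `n * k`, which is why the sparse format pays off once `k` is much smaller than `n`.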

Code

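  • csr_matrix: row slicing is efficient. Each row is the adjacency list of one datum, and fetching a batch means slicing rows, so this is the format used for the finished graph.
  • lil_matrix: described as efficient for changing the sparsity structure, so it is used while the graph is being constructed and converted to csr_matrix afterwards. (I originally built the graph directly with csr_matrix, and SciPy then suggested using lil_matrix.)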
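As a small illustration of the two roles described above, the sketch below fills a toy graph through lil_matrix, converts it to csr_matrix, and then slices a "batch" of rows. The size `n_demo`, the neighbour lists, and the variable names are made up for the example.

```python
import numpy as np
from scipy.sparse import lil_matrix

n_demo = 5  # toy graph size, chosen only for illustration
G_demo = lil_matrix((n_demo, n_demo), dtype=np.int8)

# Incremental assignment changes the sparsity structure,
# which is the kind of update lil_matrix is meant for.
G_demo[0, [1, 2]] = 1
G_demo[3, [0, 4]] = 1

# Convert once construction is finished, then slice rows as a batch.
G_demo = G_demo.tocsr()
batch = G_demo[[0, 3]]  # rows 0 and 3, still a sparse matrix
print(batch.toarray())
```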
import numpy as np
from scipy.sparse import csr_matrix, lil_matrix
# import torch


def get_kNNgraph2(data, K_num):
    """Original (dense) graph-construction function.
    https://github.com/DarrenZZhang/CDIMC-Net/blob/main/CDIMC-net-handwritten_final.py#L46
    """
    # each row of data is a sample

    x_norm = np.reshape(np.sum(np.square(data), 1), [-1, 1])  # column vector
    x_norm2 = np.reshape(np.sum(np.square(data), 1), [1, -1])  # row vector
    dists = x_norm - 2 * np.matmul(data, np.transpose(data))+x_norm2
    num_sample = data.shape[0]
    graph = np.zeros((num_sample, num_sample), dtype=int)  # `np.int` was removed in NumPy 1.24
    for i in range(num_sample):
        distance = dists[i,:]
        small_index = np.argsort(distance)
        graph[i,small_index[0:K_num]] = 1
    graph = graph-np.diag(np.diag(graph))
    resultgraph = np.maximum(graph,np.transpose(graph))
    return resultgraph


def get_kNNgraph2_sparse(X, K_num, batch_size=256):
    """sparse version of kNN graph calculation"""
    n = X.shape[0]  # full size
    # the shape must be the tuple `(n, n)`, NOT the list `[n, n]`
    G = lil_matrix((n, n), dtype=np.int8)
    x_norm_all = np.sum(np.square(X), axis=1, keepdims=True).T  # [1, n]
    for _begin in range(0, n, batch_size):
        _end = min(_begin + batch_size, n)
        X_batch = X[_begin: _end]
        # euclidean distance
        x_norm = np.sum(np.square(X_batch), axis=1, keepdims=True)  # [batch_size, 1]
        D = x_norm - 2 * np.matmul(X_batch, np.transpose(X)) + x_norm_all  # [batch_size, n]
        small_index = np.argsort(D, axis=1)[:, :K_num]  # [batch_size, K_num]
        # mask the kNN
        for i in range(small_index.shape[0]):
            _row_id = _begin + i
            _small_idx = small_index[i]
            G[_row_id, _small_idx] = 1

    # no self-loop
    G.setdiag(0)
    # symmetrize
    G = G.maximum(G.transpose())
    # convert to `csr_matrix` for fast row slicing
    G = G.tocsr()
    return G


# Consistency check: the dense and sparse versions should produce the same graph
N = 6  # num of data
D = 3  # data dim
K = N // 2
for i in range(150):
    # print(i)
    X = np.random.permutation(N * D).reshape(N, D)

    G1 = get_kNNgraph2(X, K)
    G2 = get_kNNgraph2_sparse(X, K).todense()

    diff = (G1 != G2).sum()
    if diff != 0:
        print("diff:", i, diff)  # no output observed

    # print("PyTorch sparse matrix")
    # x_nz, y_nz = G2.nonzero()
    # I = torch.stack([  # shape [2, nnz]; torch.cat here would give a flat [2 * nnz] vector
        # torch.from_numpy(x_nz),
        # torch.from_numpy(y_nz),
    # ], 0).long()
    # V = torch.ones(x_nz.shape[0]).float()
    # break

print("DONE")
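The commented-out lines in the check above start to collect index and value tensors for a PyTorch sparse matrix. Below is a NumPy-only sketch of that step, assuming the goal is the COO layout that `torch.sparse_coo_tensor` takes (an index array of shape `[2, nnz]` plus a value array of length `nnz`). The helper name `csr_to_coo_arrays` and the toy matrix `G_small` are hypothetical.

```python
import numpy as np
from scipy.sparse import csr_matrix


def csr_to_coo_arrays(G):
    """Return (indices [2, nnz], values [nnz], shape) for a sparse matrix."""
    coo = G.tocoo()
    indices = np.stack([coo.row, coo.col]).astype(np.int64)
    values = coo.data.astype(np.float32)
    return indices, values, coo.shape


# Toy symmetric graph, used only to exercise the helper.
G_small = csr_matrix(np.array([[0, 1, 0],
                               [1, 0, 1],
                               [0, 1, 0]], dtype=np.int8))
indices, values, shape = csr_to_coo_arrays(G_small)
print(indices.shape, values.shape, shape)

# With PyTorch installed, these arrays could then be passed on as, e.g.:
#   torch.sparse_coo_tensor(torch.from_numpy(indices),
#                           torch.from_numpy(values), shape)
```

Going through `tocoo()` avoids densifying the graph first, which matters for the same memory reason that motivated the sparse rewrite.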

References

  1. DarrenZZhang/CDIMC-Net
  2. get_kNNgraph2
  3. Sparse matrices (scipy.sparse)
  4. scipy.sparse.csr_matrix
  5. scipy.sparse.lil_matrix
  6. torch.sparse