CDIMC-Net[1] 中有一个对整个数据集求 kNN 图的函数 `get_kNNgraph2`[2]，它用稠密（dense）的 `numpy.ndarray` 存储邻接矩阵，空间复杂度为 $O(n^2)$，在大数据集上非常吃内存；但 kNN 图其实很稀疏（每个样本只连向约 K 个近邻）。这里用 `scipy.sparse` 的 API 将其改写。
Code
- `csr_matrix`：行切片（row slicing）高效。图的每一行对应一个样本（datum）的邻接表，取 batch 时是按行取，所以最终结果用它存储。
- `lil_matrix`：文档称其「改变稀疏结构很高效」，因此在构图阶段使用，构造完成后再转成 `csr_matrix`（最初是直接用 `csr_matrix` 构造的，但 SciPy 给出了 `SparseEfficiencyWarning`，建议改用 `lil_matrix`）。
import numpy as np
from scipy.sparse import csr_matrix, lil_matrix
# import torch
def get_kNNgraph2(data, K_num):
    """Dense reference implementation of the kNN graph (original CDIMC-Net code).

    https://github.com/DarrenZZhang/CDIMC-Net/blob/main/CDIMC-net-handwritten_final.py#L46

    Each row of ``data`` is one sample. For every row, the ``K_num`` columns
    with the smallest squared Euclidean distance are marked with 1 (the
    sample itself is usually among them, having distance 0). The diagonal is
    then cleared and the graph is symmetrised with an element-wise maximum.

    Args:
        data: 2-D array of shape [n, d].
        K_num: number of smallest-distance columns selected per row.

    Returns:
        Dense integer ndarray of shape [n, n] with 0/1 entries. It is symmetric
        with a zero diagonal. Memory use is O(n^2).
    """
    # Squared norm of every sample, computed once and used as column and row.
    sq_norm = np.sum(np.square(data), axis=1)
    # Pairwise squared distances ||a||^2 - 2 a.b + ||b||^2, shape [n, n].
    dists = sq_norm[:, None] - 2 * np.matmul(data, np.transpose(data)) + sq_norm[None, :]
    num_sample = data.shape[0]
    # `np.int` was removed in NumPy 1.24; it was only an alias of builtin `int`.
    graph = np.zeros((num_sample, num_sample), dtype=int)
    # Row-wise argsort applies the same per-row sort as the original loop.
    small_index = np.argsort(dists, axis=1)[:, :K_num]
    graph[np.arange(num_sample)[:, None], small_index] = 1
    # Remove self-loops.
    np.fill_diagonal(graph, 0)
    # Symmetrise: i~j if j is among i's neighbours or vice versa.
    return np.maximum(graph, np.transpose(graph))
def get_kNNgraph2_sparse(X, K_num, batch_size=256):
    """Sparse, batched counterpart of ``get_kNNgraph2``.

    Builds the same symmetric 0/1 kNN graph, but stores it as a SciPy sparse
    matrix. Pairwise squared Euclidean distances are computed ``batch_size``
    rows at a time, so the dense distance block held at once is
    [batch_size, n] rather than [n, n].

    Args:
        X: 2-D array of shape [n, d]; each row is one sample.
        K_num: number of smallest-distance columns selected per row. The
            sample itself is normally one of them and is removed afterwards.
        batch_size: number of rows per distance block; must be >= 1.

    Returns:
        ``csr_matrix`` of shape [n, n] with 0/1 entries. It is symmetric with
        a zero diagonal.

    Raises:
        ValueError: if ``batch_size`` < 1. A negative value would otherwise
            skip every batch and silently return an empty graph.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got {}".format(batch_size))
    n = X.shape[0]
    # lil_matrix is cheap to fill incrementally; it is converted to CSR at the end.
    # NB: the shape is passed as a tuple `(n, n)`, not a list.
    G = lil_matrix((n, n), dtype=np.int8)
    x_norm_all = np.sum(np.square(X), axis=1, keepdims=True).T  # [1, n]
    for begin in range(0, n, batch_size):
        end = min(begin + batch_size, n)
        X_batch = X[begin:end]
        # Squared Euclidean distance: ||a||^2 - 2 a.b + ||b||^2.
        x_norm = np.sum(np.square(X_batch), axis=1, keepdims=True)  # [b, 1]
        dist = x_norm - 2 * np.matmul(X_batch, np.transpose(X)) + x_norm_all  # [b, n]
        nn_idx = np.argsort(dist, axis=1)[:, :K_num]  # [b, k]
        # One vectorised assignment per batch instead of one per row. Each row
        # of nn_idx holds distinct columns, so (row, col) pairs never repeat.
        rows = np.repeat(np.arange(begin, end), nn_idx.shape[1])
        G[rows, nn_idx.ravel()] = 1
    # No self-loops.
    G.setdiag(0)
    # Symmetrise.
    G = G.maximum(G.transpose())
    # CSR gives fast row slicing when batches of rows are taken later.
    return G.tocsr()
"""验证一致性"""
N = 6 # num of data
D = 3 # data dim
K = N // 2
for i in range(150):
# print(i)
X = np.random.permutation(N * D).reshape(N, D)
G1 = get_kNNgraph2(X, K)
G2 = get_kNNgraph2_sparse(X, K).todense()
diff = (G1 != G2).sum()
if diff != 0:
print("diff:", i, diff) # 无输出
# print("PyTorch sparse matrix")
# x_nz, y_nz = G2.nonzero()
# I = torch.cat([
# torch.from_numpy(x_nz),
# torch.from_numpy(y_nz),
# ], 0).long()
# V = torch.ones(x_nz.shape[0]).float()
# break
print("DONE")