前言
这次分享一个 numpy 里面的一个高级函数partition
,这个函数在一些搜索、匹配、找相关性的时候会用到。功能强大,但是一般人不知道、不会用,或者不知道怎么用。
这次就分享一下具体的用法,也是numpy技巧第二篇文章
。
同时代码也都是开源的,链接为:https://github.com/yuanzhoulvpi2017/tiny_python/blob/main/numpy_base,文件编号是02
开通的。
介绍
np.partition
是对一个向量、沿着一个维度方向、按照大小对数据进行分堆,分成了两堆。
- 在
kth
前面的这堆数值,都是这个向量里面比较小的群体们。 - 在
kth
后面的这堆数值,都是这个向量里面比较大的群体们。
一个简单的例子:
data = np.array([232, 564, 278, 3, 2, 1, -1, -10, -30, -40])
np.partition(data, kth=4)
np.partition(data, kth=4)
的意思就是:
np.partition
对data
说:“你们班,现在给我从左向右站好!我不需要你们完全从低到高排序好,我只要左边 4 个是你们里面最小的就行”- 然后
data
就找到班里最小的 4 个同学:-1
,-10
,-30
,-40
,说:“你们几个赶快给我站到左边,不需要你们几个再排序了,怎么快怎么来” -1
,-10
,-30
,-40
听到data
指令后,马上跑到左边站好。- 剩下的没有被指出来的,想怎么站都无所谓
与此类似的
np.partition
功能和np.argpartition
功能是一样的,只不过np.argpartition
返回的是序号
np.argpartition(data, kth=4)
np.argpartition
的意思就是:
np.argpartition
对data
说:“把你们班最差的几个人学号放在前 4 个坑,剩下人随便填上就行了。data
对-1
,-10
,-30
,-40
说,你们几个差学生也别出来了,就对我说你们学号多少,我来填上,然后-1
说我是 6 号,-10
说我说 7 号,-30
说我是 8 号,-40
说我是 9 号。- 然后剩下的人的编号,
data
敷衍了事,随便写上了。
解决问题
那么我们要想找到data
的最小的 4 个数字,其实非常简单。两个方法:
# way 1
np.partition(data, kth=4)[:4]
# way 2
data[np.argpartition(data, kth=4)[:4]]
推广扩展
那么问题来了,我想找到data
最大的 3 位数怎么办?
- 一般第 k 个,我们在 python 里面都是使用 k 为正数,也就是从左向右数第 k 个。
- 在 python 取后 k 个,其实我们都是知道的,就是使用
-k
这个方法。
那么这个方法其实在这里也是适用的,下面就是解决方法。就不过多做解释了。(注意返回的结果的后三个数值)
np.partition(data, kth=-3)# 返回的是具体的值
np.argpartition(data, kth=-3) # 返回的是值对应的序号
我要是取后 3 个序号,其实就是 top3 的值了:
# way1
np.partition(data, kth=-3)[-3:]
# way 2
data[np.argpartition(data, kth=-3)[-3:]]
更高维度怎么办
上面的data
只是一维的,对于二维及更高维度的数据同样适用。
这里要注意一个小细节:
- 假设一个数组的 shape 是: (m,n,z)
那么axis=1
的方向其实就是沿着第二个也就是n
这个方向。希望可以帮助读者分清楚.
实际问题
假如你是一个研究疾病的研究生,手上有个数据:
- 有一系列数据:其中有 425 个
疾病名称
,有 13426 个症状
- 还有一个
疾病名称
和症状
的权重matrix
,矩阵的 shape 为425 x 13426
需要解决的问题是:
需要按照权重matrix
找到每个疾病名称
前 10 个最相关的症状
,并且记录下来。
解析
- 这里需要处理的数组变成了二维数组,找 top10(不需要排序,只要找到),并且记录下来。
- 这里使用
np.argpartition
可以一次性将所有的 topk 找出来,大大的提高了计算效率
这里分享代码
import numpy as np
import pandas as pd
from tqdm import tqdm
# generate sample data
n_features = 13426
n_disease = 425
features = [f"feature_{i}" for i in range(n_features)]
disease = [f"disease_{i}" for i in range(n_disease)]
weights = np.random.random((n_disease, n_features))
#function
def getdata(top_k: int) -> pd.DataFrame:
index = np.argpartition(weights, -top_k, axis=1)[:, -top_k:]
def slice_data(i):
temp_data = pd.DataFrame({
'features': np.array(features)[index[i, :]]})
temp_data['disease'] = disease[i]
temp_data['weights'] = weights[i, index[i, :]]
return temp_data
res = pd.concat([slice_data(i) for i in tqdm(range(weights.shape[0]))]).reset_index(drop=True)
return res
final_data = getdata(top_k=3) # 这里只是找top3的,要是想找top10的,修改数值就行了
final_data.shape
final_data.head(4)