np.partition介绍

前言

这次分享一个 numpy 里面的一个高级函数partition,这个函数在一些搜索、匹配、找相关性的时候会用到。功能强大,但是一般人不知道、不会用,或者不知道怎么用。

这次就分享一下具体的用法,也是numpy技巧第二篇文章

同时代码也都是开源的,链接为:https://github.com/yuanzhoulvpi2017/tiny_python/blob/main/numpy_base,文件编号是02开通的。

介绍

np.partition是对一个向量、沿着一个维度方向、按照大小对数据进行分堆,分成了两堆。

  1. kth前面的这堆数值,都是这个向量里面比较小的群体们。
  2. kth后面的这堆数值,都是这个向量里面比较大的群体们。
一个简单的例子:
data = np.array([232, 564, 278, 3, 2, 1, -1, -10, -30, -40])
np.partition(data, kth=4)

np.partition(data, kth=4)的意思就是:

  1. np.partitiondata说:“你们班,现在给我从左向右站好!我不需要你们完全从低到高排序好,我只要左边 4 个是你们里面最小的就行”
  2. 然后data就找到班里最小的 4 个同学:-1, -10, -30, -40,说:“你们几个赶快给我站到左边,不需要你们几个再排序了,怎么快怎么来”
  3. -1, -10, -30, -40听到data指令后,马上跑到左边站好。
  4. 剩下的没有被指出来的,想怎么站都无所谓
与此类似的
  1. np.partition功能和np.argpartition功能是一样的,只不过np.argpartition返回的是序号
np.argpartition(data, kth=4)

np.argpartition的意思就是:

  1. np.argpartitiondata说:“把你们班最差的几个人学号放在前 4 个坑,剩下人随便填上就行了。
  2. data-1, -10, -30, -40说,你们几个差学生也别出来了,就对我说你们学号多少,我来填上,然后-1说我是 6 号,-10说我说 7 号,-30说我是 8 号,-40说我是 9 号。
  3. 然后剩下的人的编号,data敷衍了事,随便写上了。
解决问题

那么我们要想找到data的最小的 4 个数字,其实非常简单。两个方法:

# way 1
np.partition(data, kth=4)[:4]

# way 2
data[np.argpartition(data, kth=4)[:4]]

推广扩展

那么问题来了,我想找到data最大的 3 位数怎么办?

  1. 一般第 k 个,我们在 python 里面都是使用 k 为正数,也就是从左向右数第 k 个。
  2. 在 python 取后 k 个,其实我们都是知道的,就是使用-k这个方法。

那么这个方法其实在这里也是适用的,下面就是解决方法。就不过多做解释了。(注意返回的结果的后三个数值)

np.partition(data, kth=-3)# 返回的是具体的值

np.argpartition(data, kth=-3) # 返回的是值对应的序号

我要是取后 3 个序号,其实就是 top3 的值了:

# way1
np.partition(data, kth=-3)[-3:]

# way 2
data[np.argpartition(data, kth=-3)[-3:]]

更高维度怎么办

上面的data只是一维的,对于二维及更高维度的数据同样适用。
这里要注意一个小细节:

  1. 假设一个数组的 shape 是: (m,n,z)
    那么axis=1的方向其实就是沿着第二个也就是n这个方向。希望可以帮助读者分清楚.

实际问题

假如你是一个研究疾病的研究生,手上有个数据:

  1. 有一系列数据:其中有 425 个疾病名称,有 13426 个症状
  2. 还有一个疾病名称症状权重matrix,矩阵的 shape 为425 x 13426

需要解决的问题是:
需要按照权重matrix找到每个疾病名称前 10 个最相关的症状,并且记录下来。

解析
  1. 这里需要处理的数组变成了二维数组,找 top10(不需要排序,只要找到),并且记录下来。
  2. 这里使用np.argpartition可以一次性将所有的 topk 找出来,大大的提高了计算效率
这里分享代码

import numpy as np
import pandas as pd
from tqdm import tqdm

# generate sample data
n_features = 13426
n_disease = 425
features = [f"feature_{i}" for i in range(n_features)]
disease = [f"disease_{i}" for i in range(n_disease)]
weights = np.random.random((n_disease, n_features))


#function

def getdata(top_k: int) -> pd.DataFrame:
    index = np.argpartition(weights, -top_k, axis=1)[:, -top_k:]

    def slice_data(i):
        temp_data = pd.DataFrame({
            'features': np.array(features)[index[i, :]]})
        temp_data['disease'] = disease[i]
        temp_data['weights'] = weights[i, index[i, :]]
        return temp_data

    res = pd.concat([slice_data(i) for i in tqdm(range(weights.shape[0]))]).reset_index(drop=True)
    return res


final_data = getdata(top_k=3) # 这里只是找top3的,要是想找top10的,修改数值就行了
final_data.shape
final_data.head(4)
结果如下

  • 5
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

yuanzhoulvpi

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值