bool索引的使用
def delete_Pdata(data, label):
"""去除标签为1的片段,如果标签向量和大于0"""
for i in range(len(label)):
if label[i].sum()>0:
data=np.delete(data,i,axis=0)
label=np.delete(label,i,axis=0)
return data,label
背景:label是一个9000*1000的向量,希望删除其中标签中含有1的样本,最开始使用循环删除,发现:range(len(label))并不随下面的删除进行而动态改变,导致最终出现索引错误。
正确做法使用bool索引:
import numpy as np
#bool 索引
label=[0,1,1,1,0,0,1,0,1,0]
data=np.random.rand(10,2)
print(data)
data数据为一个10*2的矩阵,接下来创建bool 索引:
mask=(np.array(label)==1)
print(mask)
print(data[mask])
从下面的结果图片中可以看出,mask获得了一个与data行长度一致的bool序列,使用这个bool序列实现了对于data行向量的筛选
从而基于bool索引改进代码:
def delete_Pdata(data, label):
"""去除标签为1的片段,如果标签向量和大于0"""
# 创建一个布尔数组,标记需要保留的行
mask = label.sum(axis=1) <= 0
# 使用布尔索引选择需要保留的数据和标签
filtered_data = data[mask]
filtered_label = label[mask]
return filtered_data, filtered_label
np.where的使用
详细用法见:python numpy.where()函数的用法_np.where()函数的用法-CSDN博客
1、实现元素位置的查找
np.where(x>7)将会返回满足条件的元素索引
x[np.where(x>7)]将会返回满足条件的元素值
2、实现条件选择的元素替换
np.where(condition,x,0) 满足条件的仍然用原来的元素,不满足条件的则用0替换
np.where(condition ,1,0)满足条件的用0,不满足用1