使用python实现对样本的分层均衡抽样

前言

  • 文章来源:CSDN@LawsonAbs
  • 代码见我Github

1. 需求

在深度学习中,我们时常会碰到样本不均衡的情况,这种情况下,如果我们的样本没有给出dev集合,那么就需要手动分割train得到一份dev,但是dev的分割最好不要随机实现,而是要按照类别的个数分类取样。众所周知,sklearn中实现了常用的机器学习方法,但是 sklearn 中没有相应的函数实现此需求,【只有一个叫做 StratifiedKFold 的类,它只能做到尽可能按照类别取数,但还是有遗漏】
所以我们自己实现一份分层抽样算法,实现如下。

2.代码

def split_data_by_class(x,y,rate,seed=22):
    '''
    x,y 表示训练数据的输入和标签
    如果random = True, 则随机分割~
    rate表示dev所占数据的比例,如果不足1条,则按1条处理

    returns:
        train_idx,dev_idx
    '''    
    # 先shuffle一下再说,保持相同的shuffle  
    random.seed(seed)
    random.shuffle(y)
    random.seed(seed)
    random.shuffle(x)

    cont_id = {} # 每个类别都放到一个list中    
    # 保持每个类别划分相同
    for i in range(len(y)):
        y_idx = y[i] # y_idx是个分类值
        if y_idx not in cont_id.keys():
            cont_id[y_idx] = []
        cont_id[y_idx].append(i) # 将该类别的下标放到其中
    
    train_idx,dev_idx = [],[] # 最后返回的值,是一个下标
    # 保持按照类别均匀抽取,一共有35个类
    for item in cont_id.items():
        key,value = item # key 是类别信息,value是该类别的所有(在y中)下标
        if len(value) >= 2: # 如果当前类别的样本数大于2,才分成dev
            mid = int(len(value)*rate)
            if mid == 0: # 如果只有0,那么就得将其改成1
                mid = 1
            for i in range(mid): # 按照个数取前 rate * len(value) 个
                dev_idx.append(value[i])         
            for i in range(mid,len(value)): # 从上次的开始,作为train 集合
                train_idx.append(value[i])
        else:            
             train_idx.extend(value) # 将整个值都作为train
    
    # 得到值后再次shuffle,避免输出连续相同的值
    random.seed(seed)
    random.shuffle(train_idx)
    random.seed(seed)
    random.shuffle(dev_idx)
    
    # 需要判断一下最后返回的dev_idx 的长度是否符合要求,如果不合要求,则要随机再抽取一点儿到其中
    while len(dev_idx) < len(train_idx) * rate:
        idx = random.choice(train_idx)        
        dev_idx.append(idx) # 放到dev中
        train_idx.remove(idx) # 删除这个idx

    return train_idx,dev_idx # 返回train/dev 数据所在的下标
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

说文科技

看书人不妨赏个酒钱?

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值