PyTorch: using a mask to extract the non-padded true and pred entries from padded hierarchical data

import torch

def collect_NotMask_sents(y_true, y_out, mask):
    '''
    y_true, y_out: shape = (batch, max_len, n_class)
    mask: shape = (batch, max_len); pad positions are 0, real sentences are 1

    return:
        all_true, all_out: the non-masked sentences, shape = [n, n_class]
    '''
    batch_size = y_true.shape[0]
    # number of real (non-pad) sentences in each of the batch samples
    lens = torch.sum(mask, dim=-1).long()

    true_parts, out_parts = [], []
    for i in range(batch_size):
        # keep only the first lens[i] rows of sample i; the rest is padding
        true_parts.append(y_true[i, :lens[i]])
        out_parts.append(y_out[i, :lens[i]])

    all_true = torch.cat(true_parts, dim=0)
    all_out = torch.cat(out_parts, dim=0)
    return all_true, all_out
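
Since the mask already marks exactly which rows to keep, the same selection can be written without a Python loop by boolean indexing. A minimal sketch, assuming mask is the same 0/1 tensor as above (collect_NotMask_sents_vectorized is a name introduced here for illustration):

def collect_NotMask_sents_vectorized(y_true, y_out, mask):
    # indexing a (batch, max_len, n_class) tensor with a
    # (batch, max_len) boolean mask flattens the selected rows
    # into shape [n, n_class] in one step
    keep = mask.bool()
    return y_true[keep], y_out[keep]

Unlike the loop above, which assumes all real sentences precede the padding, boolean indexing also handles masks with gaps in the middle.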

Example:

y_out.shape = [3,6,4]
tensor([[[-0.0749, -0.3202, -1.3574, -1.0878],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[-0.0620, -0.5731, -1.7255, -2.6526],
         [ 0.5735, -1.9126, -1.8282, -2.7244],
         [ 0.5664, -0.9909, -2.0089, -2.1861],
         [ 3.2328, -3.2824, -3.3367, -2.7656],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.3660, -1.6101, -1.7406, -1.3969],
         [ 0.9982, -1.4341, -1.3564, -1.2066],
         [ 1.4501, -1.1031, -2.2286, -1.0969],
         [ 1.1524, -1.9015, -1.4164, -1.5932],
         [ 0.3268, -0.8334, -1.5926, -1.4561],
         [ 0.0000,  0.0000,  0.0000,  0.0000]]]) 

y_true.shape = [3,6,4]
tensor([[[1., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 0., 1., 0.],
         [0., 1., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [1., 0., 0., 0.],
         [0., 0., 0., 0.],
         [1., 0., 1., 0.],
         [0., 0., 0., 0.]]]) 

mask.shape = [3,6]
tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.]])

all_true, all_out = collect_NotMask_sents(y_true,y_out,mask)
print(all_true, all_out)

all_true:
tensor([[1., 0., 0., 0.],
        [1., 0., 1., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 0.],
        [1., 0., 1., 0.]])
all_out:
tensor([[-0.0749, -0.3202, -1.3574, -1.0878],
        [-0.0620, -0.5731, -1.7255, -2.6526],
        [ 0.5735, -1.9126, -1.8282, -2.7244],
        [ 0.5664, -0.9909, -2.0089, -2.1861],
        [ 3.2328, -3.2824, -3.3367, -2.7656],
        [ 0.3660, -1.6101, -1.7406, -1.3969],
        [ 0.9982, -1.4341, -1.3564, -1.2066],
        [ 1.4501, -1.1031, -2.2286, -1.0969],
        [ 1.1524, -1.9015, -1.4164, -1.5932],
        [ 0.3268, -0.8334, -1.5926, -1.4561]]) 
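
Once the padding rows are removed, the flattened tensors can go straight into a loss. A hedged sketch, assuming y_out holds raw logits for a multi-label task (the multi-hot rows of y_true suggest this, but the post does not say):

import torch.nn.functional as F

all_true, all_out = collect_NotMask_sents(y_true, y_out, mask)
# BCE over only the real (non-pad) sentences, so the zero
# padding rows never contribute to the loss or the gradient
loss = F.binary_cross_entropy_with_logits(all_out, all_true)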