def collect_NotMask_sents(y_true, y_out, mask):
    """Gather the non-padded sentences of a batch into flat 2-D tensors.

    Args:
        y_true: label tensor, shape (batch, max_len, n_class).
        y_out: model-output tensor, shape (batch, max_len, n_class).
        mask: shape (batch, max_len); 0 marks a padded position, 1 marks a
            real sentence. Only the per-sample count of ones is used: the
            first ``sum(mask[i])`` positions of sample ``i`` are kept, so the
            real sentences are assumed to form a contiguous prefix.

    Returns:
        (all_true, all_out): the kept sentences of every sample concatenated
        in batch order, each of shape (n, n_class), where n is the total
        number of real sentences. For an empty batch, both are empty
        (0, n_class) tensors with the dtype/device of the matching input.
    """
    batch_size = y_true.shape[0]
    n_class = y_true.shape[-1]
    # Number of real (non-padded) sentences per sample, converted to Python
    # ints once so slicing below does not index with 0-d tensors each step.
    lens = torch.sum(mask, dim=-1).long().tolist()
    if batch_size == 0:
        # torch.cat cannot concatenate an empty list; return correctly typed
        # empty results instead of float64 CPU tensors.
        return y_true.new_zeros((0, n_class)), y_out.new_zeros((0, n_class))
    # Collect the per-sample slices and concatenate once, instead of
    # re-copying the growing result on every iteration.
    all_true = torch.cat([y_true[i, :lens[i]] for i in range(batch_size)], dim=0)
    all_out = torch.cat([y_out[i, :lens[i]] for i in range(batch_size)], dim=0)
    return all_true, all_out
example:
y_out.shape = [3,6,4]
tensor([[[-0.0749, -0.3202, -1.3574, -1.0878],
[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000]],
[[-0.0620, -0.5731, -1.7255, -2.6526],
[ 0.5735, -1.9126, -1.8282, -2.7244],
[ 0.5664, -0.9909, -2.0089, -2.1861],
[ 3.2328, -3.2824, -3.3367, -2.7656],
[ 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000]],
[[ 0.3660, -1.6101, -1.7406, -1.3969],
[ 0.9982, -1.4341, -1.3564, -1.2066],
[ 1.4501, -1.1031, -2.2286, -1.0969],
[ 1.1524, -1.9015, -1.4164, -1.5932],
[ 0.3268, -0.8334, -1.5926, -1.4561],
[ 0.0000, 0.0000, 0.0000, 0.0000]]])
y_true.shape = [3,6,4]
tensor([[[1., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[1., 0., 1., 0.],
[0., 1., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]],
[[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 0.],
[1., 0., 1., 0.],
[0., 0., 0., 0.]]])
mask.shape = [3,6]
tensor([[1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 0., 0.],
[1., 1., 1., 1., 1., 0.]])
all_true, all_out = collect_NotMask_sents(y_true,y_out,mask)
print (all_true, all_out)
all_true:
tensor([[1., 0., 0., 0.],
[1., 0., 1., 0.],
[0., 1., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 0.],
[1., 0., 1., 0.]])
all_out:
tensor([[-0.0749, -0.3202, -1.3574, -1.0878],
[-0.0620, -0.5731, -1.7255, -2.6526],
[ 0.5735, -1.9126, -1.8282, -2.7244],
[ 0.5664, -0.9909, -2.0089, -2.1861],
[ 3.2328, -3.2824, -3.3367, -2.7656],
[ 0.3660, -1.6101, -1.7406, -1.3969],
[ 0.9982, -1.4341, -1.3564, -1.2066],
[ 1.4501, -1.1031, -2.2286, -1.0969],
[ 1.1524, -1.9015, -1.4164, -1.5932],
[ 0.3268, -0.8334, -1.5926, -1.4561]])