import torch
# Demo: sum-pool a padded batch of variable-length sequences, ignoring padding.
# x has shape (batch=2, max_len=3, dim=4): 2 samples, max length 3, feature dim 4.
# Every position holds the same vector so the mask's effect on the sum is obvious.
x = torch.tensor([
    [[1.0, 1.0, 3.0, 4.0], [1.0, 1.0, 3.0, 4.0], [1.0, 1.0, 3.0, 4.0]],
    [[1.0, 1.0, 3.0, 4.0], [1.0, 1.0, 3.0, 4.0], [1.0, 1.0, 3.0, 4.0]],
])
print(x)
# Number of real (non-padding) time steps in each sample.
true_length = torch.tensor([2, 1])
print(true_length)
# Position j of sample i is valid iff j < true_length[i]. The time-axis length is
# taken from x itself rather than hard-coded, so the mask always matches x.
mask = torch.arange(x.size(1))[None, :] < true_length[:, None]  # (batch, max_len) bool
mask = mask.float()
print(mask)
# Add a trailing axis so the mask broadcasts across the feature dimension.
mask = mask.unsqueeze(2)  # (batch, max_len, 1)
print(mask)
masked = x * mask  # padding positions are zeroed out
print(masked)
# Collect two pooled (batch, dim) tensors to demonstrate concatenation.
a = []
a.append(torch.sum(masked, dim=1))
a.append(torch.sum(masked, dim=1))
print(a)
# Concatenate pooled features along the feature axis -> (batch, 2 * dim).
# Keep and print the result; a bare expression is a no-op outside a notebook.
pooled = torch.cat(a, dim=1)
print(pooled)
# pytorch 不定长序列 mask后 sum — PyTorch: summing variable-length sequences after masking
# Build a padding mask directly from a vector of sequence lengths.
# Lengths are counts, so keep them integral: the legacy torch.Tensor constructor
# stores them as float32, which makes .max() float and arange() yield float positions.
sentence_lengths = torch.tensor([7, 10, 4])  # length of each sentence
print(sentence_lengths)
max_len = int(sentence_lengths.max())
# mask[i, j] is True iff position j lies inside sentence i -> (num_sentences, max_len) bool
mask = torch.arange(max_len)[None, :] < sentence_lengths[:, None]
print(mask)
# pytorch 不定长序列 mask后 sum — PyTorch: summing variable-length sequences after masking