最近在使用pack_padded_sequence出现了RuntimeError: Length of all samples has to be greater than 0, but found an element in ‘lengths’ that is <= 0这个错误,刚开始百思不得其解,后来发现 问题出现在pack_padded_sequence(seq, seq_lengths, batch_first=True, enforce_sorted=False),里面的参数seq_lengths中,下面举个例子。
import torch
from torch.nn.utils.rnn import pad_sequence,pack_padded_sequence,pack_sequence,pad_packed_sequence
out = torch.tensor([[1,2,3,5],[3,5,6,0],[0,0,0,0]])
seq_lens = [4,3,0]
out = pack_padded_sequence(out, seq_lens, batch_first =True,enforce_sorted=False)
print(out)
这段代码就会报这个RuntimeError: Length of all samples has to be greater than 0, but found an element in ‘lengths’ that is <= 0错误,主要是因为表示长度的lens列表中有0,所以会报这个错。