错误
The size of tensor a (3) must match the size of tensor b (5) at non-singleton dimension 1
在使用nn.utils.rnn.pad_sequence时,遇到如上错误,原因是使用方式错误.
使用说明
用padding_value填充可变长度张量列表
pad_sequence 沿新维度堆叠张量列表,
并将它们垫成相等的长度。
例如,如果输入是列表
大小为“L x *”的序列,如果batch_first为False,并且“T x B x *”
“B”是批量大小。它等于“序列”中元素的数量。
“T”是最长序列的长度。
“L”是序列的长度。
“*”是任意数量的尾随维度,包括没有。
例子:
>>> from torch.nn.utils.rnn import pad_sequence
>>> a = torch.ones(25, 300)
>>> b = torch.ones(22, 300)
>>> c = torch.ones(15, 300)
>>> pad_sequence([a, b, c]).size()
torch.Size([25, 3, 300])
注意:
该函数返回大小为“T x B x *”或“B x T x *”的张量
其中“T”是最长序列的长度。该函数假设
序列中所有张量的尾随维度和类型都是相同的。
参数:
序列 (list[Tensor]):可变长度序列的列表。
batch_first(bool,可选):如果为 True,输出将在“B x T x *”中,否则在
``T x B x *`` 否则。默认值:假。
padding_value (float,可选):填充元素的值。默认值:0。
返回:
如果:attr:`batch_first` 为``False``,则大小为``T x B x *`` 的张量。
大小为“B x T x *”的张量,否则反过来
样例代码
最后一维必须一致,可以理解为embeding层
from torch import nn
import torch
a = torch.randn(3,5)
b = torch.randn(2,5)
out = nn.utils.rnn.pad_sequence([a,b])
print(out)
当维度大于2时, 一般会包含batch size,所以要指定batch_size是否是第一维度
from torch import nn
import torch
a = torch.randn(4,3,5)
b = torch.randn(2,3,5)
out = nn.utils.rnn.pad_sequence([a,b], batch_first=False)
print(out)