【pytorch】nn.utils.rnn.pad_sequence的使用


错误

The size of tensor a (3) must match the size of tensor b (5) at non-singleton dimension 1

在使用nn.utils.rnn.pad_sequence时,遇到如上错误,原因是使用方式错误.

使用说明

用padding_value填充可变长度张量列表
pad_sequence 沿新维度堆叠张量列表,
并将它们垫成相等的长度。
例如,如果输入是列表
大小为“L x *”的序列,如果batch_first为False,并且“T x B x *”

“B”是批量大小。它等于“序列”中元素的数量。
“T”是最长序列的长度。
“L”是序列的长度。
“*”是任意数量的尾随维度,包括没有。

例子:
    >>> from torch.nn.utils.rnn import pad_sequence
    >>> a = torch.ones(25, 300)
    >>> b = torch.ones(22, 300)
    >>> c = torch.ones(15, 300)
    >>> pad_sequence([a, b, c]).size()
    torch.Size([25, 3, 300])

注意:
    该函数返回大小为“T x B x *”或“B x T x *”的张量
    其中“T”是最长序列的长度。该函数假设
    序列中所有张量的尾随维度和类型都是相同的。

参数:
    序列 (list[Tensor]):可变长度序列的列表。
    batch_first(bool,可选):如果为 True,输出将在“B x T x *”中,否则在
        ``T x B x *`` 否则。默认值:假。
    padding_value (float,可选):填充元素的值。默认值:0。

返回:
    如果:attr:`batch_first` 为``False``,则大小为``T x B x *`` 的张量。
    大小为“B x T x *”的张量,否则反过来

样例代码

最后一维必须一致,可以理解为embeding层

from torch import nn
import torch

a = torch.randn(3,5)
b = torch.randn(2,5)

out = nn.utils.rnn.pad_sequence([a,b])
print(out)

当维度大于2时, 一般会包含batch size,所以要指定batch_size是否是第一维度

from torch import nn
import torch

a = torch.randn(4,3,5)
b = torch.randn(2,3,5)

out = nn.utils.rnn.pad_sequence([a,b], batch_first=False)
print(out)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值