pytorch 之pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence使用

pad_sequence

该函数用padding_value来填充一个可变长度的张量列表。将长度较短的序列填充为和最长序列相同的长度。
一句话就是:填充句子到相同长度。
参数说明:

  • sequences(list[Tensor]):变长序列的列表。
  • batch_frist(bool,optional):如果为True,output形状为B × T × ∗ ,否则为T × B × ∗ ,默认情况为False。其中B BB为批次大小,T TT为填充后每个序列的长度。
  • padding_value(float,optional):填充元素的值。默认值:0。

输出:

如果 batch_first 是 False,张量的形状为T × B × ∗ 。否则,张量的形状为B × T × ∗ 。
举个栗子:

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence
content = [[12, 12, 11, 1, 21, 7, 7], [12, 12, 11, 1, 21], [12, 12, 11, 21]]
DATA = list(map(lambda x: torch.tensor(x), content))
p1 = pad_sequence(DATA, batch_first=True)
print(p1)

在这里插入图片描述

pack_padded_sequence

压紧(pack)一个包含可变长度的填充序列的张量,在使用pad_sequence函数进行填充的时候,产生了冗余,因此需要对其进行pack。
参数说明:

  • input(Tensor):一批量填充后的可变长度的序列。
  • lenghts(Tensor or list(int)):每个批次元素的序列长度列表。如果输入为张量形式则必须在CPU上,不能在GPU上。
  • batch_first(bool,optional):如果为True,则输入的形状为B × T × ∗,我一般将其设置为True
  • enforce_sorted(bool,optional):如果为True,则参数lenghts为按长度递减排序的序列,这样的话输入的input也需要进行排序。我一般将其设置为False。如果为False输入将被无条件地排序。
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence

content = [[12, 12, 11, 1, 21, 7, 7], [12, 12, 11, 1, 21], [12, 12, 11, 21]]
DATA = list(map(lambda x: torch.tensor(x), content))
print(content, DATA)
p1 = pad_sequence(DATA, batch_first=True)
print(p1)
p2 = pack_padded_sequence(p1, [7, 5, 4], batch_first=True, enforce_sorted=False)
print(p2)

在这里插入图片描述

函数对返回的结果进行填充以恢复为原来的形状。
参数说明:

  • sequence(PackedSequence):需要填充的数据。
  • batch_first(bool,optional):如果为True,输出形状为B × T × ∗ B \times T \times *B×T×∗。
  • padding_value(float,optional):填充元素的值。
  • total_lenght(int,optional):如果不是无,输出将被填充成total_lenght。

输出:

包含填充序列的张量的元组,以及包含批次中每个序列的长度列表的张量。

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence

content = [[12, 12, 11, 1, 21, 7, 7], [12, 12, 11, 1, 21], [12, 12, 11, 21]]
DATA = list(map(lambda x: torch.tensor(x), content))
print(content, DATA)
p1 = pad_sequence(DATA, batch_first=True)
print(p1)
p2 = pack_padded_sequence(p1, [7, 5, 4], batch_first=True, enforce_sorted=False)
print(p2)
p3 = pad_packed_sequence(p2, batch_first=True)
print(p3)

在这里插入图片描述

pack_sequence

sequences (list[Tensor]): A list of sequences of decreasing length.enforce_sorted (bool, optional): if True, checks that the input contains sequences sorted by length in a decreasing order. If False, this condition is not checked. Default: True.

from torch.nn.utils.rnn import pack_sequence
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = torch.tensor([6])
print(pack_sequence([a, b, c], True))

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

<编程路上>

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值