pytorch中的pack_padded_sequence和pad_packed_sequence用法

pack_padded_sequence是将句子按照batch优先的原则记录每个句子的词,变化为不定长tensor,方便计算损失函数。

pad_packed_sequence是将pack_padded_sequence生成的结构转化为原先的结构,定长的tensor。

其中test.txt的内容

As they sat in a nice coffee shop, he was too nervous to say anything
and she felt uncomfortable. Suddenly, he asked the waiter, “Could
you please give me some salt? I’d like to put it in my coffee.”

具体参见如下代码

import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import wordfreq

vocab = {}
token_id = 1
lengths = []

#读取文件,生成词典
with open('test.txt', 'r') as f:
    lines=f.readlines()
    for line in lines:
        tokens = wordfreq.tokenize(line.strip(), 'en')
        lengths.append(len(tokens))
        #将每个词加入到vocab中,并同时保存对应的index
        for word in tokens:
            if word not in vocab:
                vocab[word] = token_id
                token_id += 1

x = np.zeros((len(lengths), max(lengths)))
l_no = 0
#将词转化为数字
with open('test.txt', 'r') as f:
    lines = f.readlines()
    for line in lines:
        tokens = wordfreq.tokenize(line.strip(), 'en')
        for i in range(len(tokens)):
            x[l_no, i] = vocab[tokens[i]]
        l_no += 1

x=torch.Tensor(x)
x = Variable(x)
print(x)
'''
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
        [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.]])
'''
lengths = torch.Tensor(lengths)
print(lengths)#tensor([ 8., 11.,  5., 14.])

_, idx_sort = torch.sort(torch.Tensor(lengths), dim=0, descending=True)
print(_) #tensor([14., 11.,  8.,  5.])
print(idx_sort)#tensor([3, 1, 0, 2])

lengths = list(lengths[idx_sort])#按下标取元素 [tensor(14.), tensor(11.), tensor(8.), tensor(5.)]
t = x.index_select(0, idx_sort)#按下标取元素
print(t)
'''
tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.],
        [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
        [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
        [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])
'''
x_packed = nn.utils.rnn.pack_padded_sequence(input=t, lengths=lengths, batch_first=True)
print(x_packed)
'''
PackedSequence(data=tensor([24.,  9.,  1., 20., 25., 10.,  2.,  9., 26., 11.,  3., 21., 27., 12.,
         4., 22., 28., 13.,  5., 23., 29., 14.,  6., 30., 15.,  7., 31., 16.,
         8., 32., 17., 13., 18., 33., 19., 34.,  4.,  7.]), batch_sizes=tensor([4, 4, 4, 4, 4, 3, 3, 3, 2, 2, 2, 1, 1, 1]))
'''


x_padded = nn.utils.rnn.pad_packed_sequence(x_packed, batch_first=True)#x_padded是tuple
print(x_padded)
'''
(tensor([[24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.],
        [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
        [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
        [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]]), tensor([14, 11,  8,  5]))
'''
#还原tensor
_, idx_unsort = torch.sort(idx_sort)
output = x_padded[0].index_select(0, idx_unsort)
print(output)
'''
tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,  0.,  0.,  0.],
        [20.,  9., 21., 22., 23.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [24., 25., 26., 27., 28., 29., 30., 31., 32., 13., 33., 34.,  4.,  7.]])
'''
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值