Pytorch中的 pad_packed_sequence解析

参考:https://www.cnblogs.com/sbj123456789/p/9834018.html

参考的博客有一个错误: 最后在调用pad_packed_sequence的时候,没有添加batch_first=True,导致结果错误!输出的维度信息不对!!请注意!!

import torch
import torch.nn as nn
import numpy as np

input = torch.from_numpy(np.array([[1,2,3,4],[5,6,7,0],[9,3,0,0]]))
length = [4,3,2] # lengths array has to be sorted in decreasing order
result = torch.nn.utils.rnn.pack_padded_sequence(input,lengths=length,batch_first=True)

print("==============二维输出=============")
print(result)

# input = torch.randn(8,10,300)
# length = [10,10,10,10,10,10,10,10]
#注意,这里length代表每一个batch里面数据的长度,length的数量是8个,因为对应batch_first=True,也就是input的第一维是8个,也就是有8个batch
# perm = torch.LongTensor(range(8))
# result = torch.nn.utils.rnn.pack_padded_sequence(input[perm],lengths=length,batch_first=True)
# print(result)

input = torch.randn(2,3,4)
print("===========三维原始向量================")
print(input)
length = [3,3]
perm = torch.LongTensor(range(2))
result = torch.nn.utils.rnn.pack_padded_sequence(input[perm],lengths=length,batch_first=True)
print("=============三维原始向量==batchsize变换============")
print(result)

# input = torch.randn(2,3,4)
input = torch.FloatTensor([[1,2,3],[1,0,0]]).resize_(2,3,1)
print("===========三维原始向量================")
print(input)
length = [2,1,1]  # 这里的legth里面的元素代表每个batch元素的长度,这里的batchsize个数是3,所以length是3维
#同时每一个batch最大长度是2,第一个batch,其元素都是2,所以长度为2 ,第二个和第三个只有第一个元素不为0,所以长度是1
#请记住,一定是 代表每个batch元素的长度,重要的事情所三遍!!!(一个batch的组成需要考虑多个seq——length数组)
perm = torch.LongTensor(range(2))
result = torch.nn.utils.rnn.pack_padded_sequence(input[perm],lengths=length)
print("=============三维原始向量==batchsize 不等于true变换============")
print(result)

hidden_size =1
n_layers = 1

lstm=nn.LSTM(1,hidden_size,n_layers)

encoder_outputs_packed, (h_last, c_last) = lstm(result)
print("=============encoder_outputs_packed==batchsize 不等于true变换=====")
print(encoder_outputs_packed)

encoder_outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(encoder_outputs_packed)

print("=============unpacked============")
print(encoder_outputs)

 

输出:

 

batch_first=true的情形

代码更改如下:

生成结果如图:

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值