RNN之pack_padded_sequence()和pad_packed_sequence()具体使用代码完整展现

该博客介绍了如何在PyTorch中利用GRU(门控循环单元)处理变长序列数据。首先配置输入和输出维度,创建GRU和嵌入层。接着,对输入数据进行预处理,包括填充和排序。然后,通过`pack_padded_sequence`将排序后的序列打包,供GRU处理。最后,通过`pad_packed_sequence`将结果解包并恢复原始顺序。整个过程详细展示了在处理变长序列时的数据处理流程。
摘要由CSDN通过智能技术生成

 这里仅作具体代码实现,原理见https://www.cnblogs.com/sbj123456789/p/9834018.html

import torch
# =========================
# Config
# =========================
d_input = 10
d_output = 20
RNN = torch.nn.GRU(d_input, d_output)

vocab_size = 5
emb_layer = torch.nn.Embedding(vocab_size, d_input)

# =========================
# data(Batch_size=3)
# =========================
x1 = ['you', 'are', 'a', 'promise', '[PAD]']
x2 = ['are', 'you', '[PAD]', '[PAD]', '[PAD]']
x3 = ['you', 'are', 'a', '[PAD]', '[PAD]']
x_batchsize_3 = [x1, x2, x3]
...
vocab = {"[PAD]": 0, "you": 1, "are": 2, "a": 3, "promise": 4}
...
seq_id = torch.tensor([[1, 2, 3, 4, 0], [2, 1, 0, 0, 0], [1, 2, 3, 0, 0]])

# =========================
# Embedding
# =========================

input = emb_layer(seq_id)
print("input:", input.shape)
# =========================
# RNN
# =========================
"(1) 整理:要按实际句长,由长到短来从上到下排序"
lens = seq_id.count_nonzero(dim=1) + 1  # 盘点每句实际句长
sorted_len, sorted_idx = lens.sort(dim=0, descending=True)  # 按上一步的得到的所有句长,按从大到小排序,同时获取对应索引号

# index_tensor拓展
index_sorted_idx = sorted_idx.view(-1, 1, 1).expand_as(input)
# 将B个(len,d)按上面说的顺序排序好 (这里gather函数就是以Batchsize为单位,每次抽取对应索引号的二维张量(len,d),实现排序的)
sorted_inputs = torch.gather(input, index=index_sorted_idx.long(), dim=0)  # sort by num_words

"(2)rnn_压缩"
packed_seq = torch.nn.utils.rnn.pack_padded_sequence(
                sorted_inputs, sorted_len.cpu().numpy(), batch_first=True)
"(3)rnn"
out, _ = RNN(packed_seq)

"(4)rnn_解压"
y, _ = torch.nn.utils.rnn.pad_packed_sequence(
                out, batch_first=True)  # currently in WRONG order!

"(5)还原回之前的句子排序"
_, original_idx = sorted_idx.sort(dim=0, descending=False)  # 用作还原的索引号
unsorted_idx = original_idx.view(-1, 1, 1).expand_as(y)

output = torch.gather(y, index=unsorted_idx, dim=0).contiguous()
print("output:", output.shape)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值