# This is only a concrete implementation; for the underlying theory see https://www.cnblogs.com/sbj123456789/p/9834018.html (in Chinese).
import torch
# =========================
# Config
# =========================
d_input = 10   # embedding dimension, i.e. the GRU's input feature size
d_output = 20  # GRU hidden size (per-step output feature size)
RNN = torch.nn.GRU(d_input, d_output)  # single-layer GRU; note batch_first defaults to False here
vocab_size = 5  # number of distinct tokens, including [PAD]
emb_layer = torch.nn.Embedding(vocab_size, d_input)  # maps token ids -> d_input-dim vectors
# =========================
# data (Batch_size=3)
# =========================
# Three variable-length sentences, right-padded with [PAD] to a common length of 5.
x1 = ['you', 'are', 'a', 'promise', '[PAD]']
x2 = ['are', 'you', '[PAD]', '[PAD]', '[PAD]']
x3 = ['you', 'are', 'a', '[PAD]', '[PAD]']
x_batchsize_3 = [x1, x2, x3]
# Token -> id lookup table. [PAD] must map to 0 so that padding positions can
# later be detected as zero ids.
vocab = {"[PAD]": 0, "you": 1, "are": 2, "a": 3, "promise": 4}
# Derive the id tensor from the sentences and the vocab instead of hard-coding
# it, so the three representations can never drift apart. Shape: (3, 5).
seq_id = torch.tensor([[vocab[word] for word in sent] for sent in x_batchsize_3])
# =========================
# Embedding
# =========================
# Look up an embedding vector for every token id:
# (batch=3, max_len=5) -> (3, 5, d_input).
# NOTE(review): `input` shadows the Python builtin of the same name; it is kept
# here because later lines reference it, but consider renaming (e.g. `embedded`).
input = emb_layer(seq_id)
print("input:", input.shape)
# =========================
# RNN
# =========================
# (1) Sort the batch by true (unpadded) length, longest first, as required by
#     pack_padded_sequence with enforce_sorted=True (the default).
# BUG FIX: the true length is exactly the number of non-PAD (non-zero) ids; the
# original "+ 1" made every length one step too long, so one [PAD] token per
# sentence was fed through the GRU — defeating the purpose of packing.
lens = seq_id.count_nonzero(dim=1)  # true length of each sentence, shape (3,)
sorted_len, sorted_idx = lens.sort(dim=0, descending=True)
# Reorder the embedded batch rows to match the sorted lengths
# (fancy indexing along dim 0 is equivalent to the expand_as + gather idiom).
sorted_inputs = input[sorted_idx]
# (2) Pack: drop the padding positions so the RNN never sees them.
#     Lengths must live on the CPU regardless of where the data is.
packed_seq = torch.nn.utils.rnn.pack_padded_sequence(
    sorted_inputs, sorted_len.cpu(), batch_first=True)
# (3) Run the RNN directly on the packed sequence.
out, _ = RNN(packed_seq)
# (4) Unpack back to a padded tensor; rows are still in length-sorted order,
#     and the time dimension is the longest TRUE length (4), not 5.
y, _ = torch.nn.utils.rnn.pad_packed_sequence(
    out, batch_first=True)
# (5) Restore the original batch order: argsorting the sort indices yields the
#     inverse permutation.
_, original_idx = sorted_idx.sort(dim=0, descending=False)
output = y[original_idx].contiguous()
print("output:", output.shape)