PyTorch关于RNN序列数据的pack_pad处理

PyTorch关于RNN序列数据的pack_pad处理

在学习使用PyTorch构造RNN过程中,看到了一个HKUST的课程中关于pytorch的入门系列代码,其中有一段关于RNN序列数据的pack_pad处理看完挺有启发的。
附上:github链接

#源码+个人注解
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import numpy as np
import itertools

#返回一个包含l所有字符的列表
def flatten(l):
    return list(itertools.chain.from_iterable(l))

seqs = ['ghatmasala', 'nicela', 'chutpakodas']

#构建字符索引列表(使用了set所以字符非重复),字符索引从1开始,且按字母升序排序
vocab = ['<pad>'] + sorted(list(set(flatten(seqs))))

# 搭建模型,分别为1个embdding词嵌入模型和1个LSTM模型
embedding_size = 3
embed = nn.Embedding(len(vocab), embedding_size)
lstm = nn.LSTM(embedding_size, 5)

#转化为每个字符对应索引列表中的索引
vectorized_seqs = [[vocab.index(tok) for tok in seq]for seq in seqs]
print("vectorized_seqs", vectorized_seqs)
#输出:
#('vectorized_seqs', [[5, 6, 1, 15, 10, 1, 14, 1, 9, 1], [11, 7, 2, 4, 9, 1], [2, 6, 16, 15, 13, 1, 8, 12, 3, 1, 14]])

print([x for x in map(len, vectorized_seqs)])
#输出:
#[10, 6, 11]

# get the length of each seq in your batch
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值