引用自: 刘二大人 《PyTorch深度学习实践》
传送门: 刘二大人 《PyTorch深度学习实践》
传送门2:数据集
一、实现功能
! ! ! 给 出 一 个 名 字 n a m e , 找 到 它 对 应 的 语 言 l a n g u a g e / 国 家 c o u n t r y \color {RED} {!!! \ 给出一个名字name,找到它对应的语言language/国家country} !!! 给出一个名字name,找到它对应的语言language/国家country
![](https://i-blog.csdnimg.cn/blog_migrate/214c9c74de21e1f08f9ee284e82fa304.png)
共有来自18种语言的names。预测一个name属于哪种语言/国家。
二、模型整体架构
![](https://i-blog.csdnimg.cn/blog_migrate/224e23a982370b7e02f240b416f37332.png)
![](https://i-blog.csdnimg.cn/blog_migrate/ef094ff1e1e3380599fc9705726f688f.png)
三、亿点点准备知识和实现细节
3.1 Bi-directional GRU/LSTM/RNN
![](https://i-blog.csdnimg.cn/blog_migrate/992894474ab0ca03ceeebb040448eae2.png)
![](https://i-blog.csdnimg.cn/blog_migrate/3f0e09763a3f184743e84754c9f58a3b.png)
双 向 结 构 的 R N N 中 : \color {orange} {双向结构的RNN中:} 双向结构的RNN中:
最 终 h i d d e n 是 由 两 个 方 向 的 第 n 个 h i d d e n 拼 接 而 成 的 \color {orange} {最终hidden是由两个方向的第n个hidden拼接而成的} 最终hidden是由两个方向的第n个hidden拼接而成的
output, hidden = self.gru(gru_input, hidden)
if self.num_directions == 2:
hidden_cat = torch.cat([hidden[-1], hidden[-2]], dim = 1) # GRU为双向时,hidden = [前向的第n个hidden, 反向的第n个hidden] 连接
else:
hidden_cat = hidden[-1] # GRU为单向时,hidden = 前向的第n个hidden
3.2 对输入的name的转置处理
input = input.t() # 将input shape由BatchSize * SeqLen -> SeqLen * BatchSize
![](https://i-blog.csdnimg.cn/blog_migrate/e970976d4cc6c10f1cf4ce41e5b55283.png)
3.3 RNNClassifier类中forward()方法中的embedding
e m b e d d i n g 的 s h a p e : ( s e q L e n , b a t c h S i z e , h i d d e n S i z e ) \color {green} {embedding的shape: (seqLen, batchSize, hiddenSize)} embedding的shape:(seqLen,batchSize,hiddenSize)
embedding = self.embedding(input)
![](https://i-blog.csdnimg.cn/blog_migrate/e73c3e15fcd157af3698f7db73899773.png)
3.4 forward()方法中的pack_padded_sequence()方法
! ! ! 新 的 重 要 知 识 点 \color {tomato} {!!!新的重要知识点} !!!新的重要知识点
使用前需要from torch.nn.utils.rnn import pack_padded_sequence
返回一个PackedSequence对象
第1个参数的shape:(seqLen, batchSize, hiddenSize)
第2个参数是一个tensor, 它是每个batch element的序列长度的列表
gru_input = pack_padded_sequence(embedding, seq_lengths)
![](https://i-blog.csdnimg.cn/blog_migrate/2cfd1754bfdd64133076b2eadbfacf1d.png)