虽然对 encoder-decoder 框架的了解已经很多了,但是从未实现过,可谓是“最熟悉的陌生人了”。近期,由于研究的需要,故而参照 github 上某开源项目(pytorch-seq2seq),实现了一个句法分析系统。本文,来学习一下实现的 decoder 部分的代码。
首先是import 部分的代码
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from attention import Attention
from baseRNN import BaseRNN
if torch.cuda.is_available():
import torch.cuda as device
else:
import torch as device
在 import部分:首先导入 numpy以及 torch 中需要使用到的模块。除了公共包,此处导入了名叫 Attention以及 BaseRNN 的模块,其中BaseRNN 为对 torch.nn.rnn 模块的一个wrapper, Attention的机制也是在 seq2seq 中一个很重要的部分,用于获取解码时对于解码中某一时刻最为 care 的信息,很简短的代码,留待以后补充。
看init部分:
def __init__(self, vocab_size, max_len, input_size, hidden_size,
sos_id, eos_id,
n_layers=1, rnn_cell='gru', bidirectional=False,
input_dropout_p=0, dropout_p=0, use_attention=False):
super(DecoderRNN, self).__init__(vocab_size, max_len, input_size, hidden_size,input_dropout_p, dropout_p,n_layers, rnn_cell)
self.bidirectional_encoder = bidirectional
self.rnn = self.rnn_cell(input_size, hidden_size, n_layers, batch_first=True, dropout=dropout_p)
self.output_size = vocab_size
self.max_length = max_len
self.use_attention = use_attention
self.eos_id = eos_id
self.sos_id = sos_id
self.init_input = None
self.embedding = nn.Embedding(self.output_size, self.input_size)
if use_attention:
self.attention = Attention(self.hidden_size)
self.fflayer = nn.Linear(self.hidden_size, self.output_size)
以上是一系列解码过程中需要使用到的参数。
- bidirectional: 指明 encoder 端的输入是否为 bidirectional,用于初始化 encoder hidden
- rnn: decoder 端为一个 rnn
- output_size: decoder端 output 的“词表”大小
- max_length: 最长解码长度
- use_attention: 是否在解码端使用注意力机制构建 feature 表示
- eos_id: 辅助用于判断解码终止
- sos_id: 辅助用于解码端的第一个输入
- init_input: 目前没什么用
- embedding: 解码端 output的词表 embedding
- fflayer: 在解码时提供计算 output 的
以下则进入我们解码时的每一步时执行的操作,即为 forward one step:
def forward_step(self, input_var, hidden, encoder_outputs, function):
"""
Args:
input_var: input token ids
hidden: last hidden state
encoder_outputs: encoder-layer output
function: probs function, default is F.log_softmax
Return:
the softmax output, the hidden state save, and the attention value
"""
batch_size = input_var.size(0)
output_size = input_var.size(1)
embedded = self.embedding(input_var)
embedded = self.input_dropout(embe