版权声明:本文为原创文章,未经博主允许不得用于商业用途。
第一个RNN程序用来练手,输入上联,输出下联,使用了seq2seq模型,如下图
(Image source: https://jeddy92.github.io/JEddy92.github.io/ts_seq2seq_intro/)
模型说明
首先使用word-embedding对汉字重新编码到500维向量,之后经过encoderRNN和decoderRNN(双向GRU),其中decoderRNN通过Attention对encoder的最后一个隐藏层输出加权,decoderRNN的第一轮输入为句子起始符SOS。
- 模型使用GRU作为RNNCell,加入了Luong Attention,word-embedding是从随模型共同训练的。
- 由于输出长度不确定,因此引入句子终结符EOS,当decoderRNN输出EOS后就视作完成一次输出。
- 由于RNN很容易出现梯度爆炸,所以使用clipping和GRU作为Cell,不使用LSTM是为了减少参数,加速训练。
代码如下:
#双向GRU的编码器,输出为最后一个隐藏层的数据
class EncoderRNN(nn.Module):
def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
super(EncoderRNN, self).__init__()
self.n_layers = n_layers
self.hidden_size = hidden_size
self.embedding = embedding
# Initialize GRU; the input_size and hidden_size params are both set to 'hidden_size'
# because our input size is a word embedding with number of features == hidden_size
self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
dropout=(0 if n_layers == 1 else dropout), bidirectional=True)
def forward(self, input_seq, input_lengths, hidden=None):
# use word-embedding to preprocess input charactors
embedded = self.embedding(input_seq)
# 转化为变长的padding
packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths)
outputs, hidden = self.gru(packed, hidden)
# Unpack padding
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs)
# 双向RNN输出直接做和作为输出
outputs = outputs[:, :, :self.hidden_size] + outputs[:, : ,self.hidden_size:]
return outputs, hidden
# Luong attention layer
class Attn(nn.Module):
def __init__(self, method, hidden_size):
super(Attn, self).__init__()
self.method = method
if self.method not in ['dot', 'general', 'concat']:
raise ValueError(self.method, "is not an appropriate attention method.")
self.hidden_size = hidden_size
if self.method == 'general':
self.attn = nn.Linear(self.hidden_size, hidden_size)
elif self.method == 'concat':
self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
self.v = nn.Parameter(torch.FloatTensor(hidden_size))
def dot_score(self, hidden, encoder_output):
return torch.sum(hidden * encoder_output, dim=2)
def general_score(self, hidden, encoder_output):
energy = self.attn(encoder_output)
return torch.sum(hidden * energy, dim=2)
def concat_score(self, hidden, encoder_output):
energy = self.attn(torch.cat((hidden.expand(encoder_output.size(0), -1, -1), encoder_output), 2)).tanh()
return torch.sum(self.v * energy, dim=2)
def forward(self, hidden, encoder_outputs):
# Calculate the attention weights (energies) based on the given method
if self.method == 'general':
attn_energies = self.general_score(hidden, encoder_outputs)
elif self.method == 'concat':
attn_energies = self.concat_score(hidden, encoder_outputs)
elif self.method == 'dot':
attn_energies = self.dot_score(hidden, encoder_outputs)
# Transpose max_length and batch_size dimensions
attn_energies = attn_energies.t()
# Return the softmax normalized probability scores (with added dimension)
return F.softmax(attn_energies, dim=1).unsqueeze(1)
#使用Luong Attention的Decoder
class LuongAttnDecoderRNN(nn.Module):
def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1, dropout=0.1):
super(LuongAttnDecoderRNN, self).__init__()
# Keep for reference
self.attn_model = attn_model
self.hidden_size = hidden_size
self.output_size = output_size
self.n_layers = n_layers
self.dropout = dropout
# Define layers
self.embedding = embedding
self.embedding_dropout = nn.Dropout(dropout)
self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=(0 if n_layers == 1 else dropout))
self.concat = nn.Linear(hidden_size * 2, hidden_size)
self.out = nn.Linear(hidden_size, output_size)
self.attn = Attn(attn_model, hidden_size)
def forward(self, input_step, last_hidden, encoder_outputs):
# Note: we run this one step (word) at a time
# embedding SOS
embedded = self.embedding(input_step)
embedded = self.embedding_dropout(embedded)
# Forward through unidirectional GRU
rnn_output, hidden = self.gru(embedded, last_hidden)
# 计算Attention Weight
attn_weights = self.attn(rnn_output, encoder_outputs)
# 计算encoder output基于Attention Weight的加权和
context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
# 合并encoder output和GRU第一轮的输出
rnn_output = rnn_output.squeeze(0)
context = context.squeeze(1)
concat_input = torch.cat((rnn_output, context), 1)
concat_output = torch.tanh(self.concat(concat_input))
# 将word embedding 转化回字符
output = self.out(concat_output)
output = F.softmax(output, dim=1)
# Return output and final hidden state
return output, hidden
数据集
- 采用科赛上的中国对联训练集,包含77w+的对联,9000+的汉字,保险起见就不发到网上了
训练结果
RNN由于具有时序性,所以无法在GPU上很好的加速,因此迭代次数有限,Model文件夹为迭代29epoch后的模型。
以下对联为CharRNN的输出结果(由于每轮起始是GRU中的Memory为随机的,输出也具有随机性):
1
上联:<s>天<\s>
下联:<s>地<\s>
上联:<s>雨<\s>
下联:<s>烟<\s>
2
上联:<s>米饭<\s>
下联:<s>油茶<\s>
上联:<s>山花<\s>
下联:<s>野禽<\s>
3
上联:<s>鸡冠花<\s>
下联:<s>龙牙梨<\s>
上联:<s>孔夫子<\s>
下联:<s>毛小公<\s>
more
上联:<s>今天打雷下雨<\s>
下联:<s>昨日打人走人<\s>
上联:<s>狗和猫打架不分胜负<\s>
下联:<s>狼与狗进球就是高多<\s>
文字越多输出的连贯性越差,并且可能出现如下字数不相符的情况:
上联:<s>人生没有彩排,每一天都是现场直播<\s>
下联:<s>世海无多解势,众今岂来地网先争<\s>
个人理解是如果训练次数足够多可以获得更好的结果。
完整代码见github