Seq2seq (encoder + decoder)
A minimal sequence-to-sequence model, following the paper "Sequence to Sequence Learning with Neural Networks".
The code is as follows:
# coding = utf-8
# author = 'xy'
"""
model1: encoder + decoder
we use Bi-gru as our encoder, gru as decoder, no attention
"""
import numpy as np
import torch
from torch import nn
import test_helper
class Encoder(nn.Module):
    """Encode the source sequence with a bidirectional GRU.

    The final forward/backward hidden states of each layer are concatenated
    along the feature dimension so a caller can project them down to a
    decoder initial state.
    """
    def __init__(self, input_size, hidden_size, embedding, num_layers=1, dropout=0.2):
        """
        :param input_size: embedding dimension fed to the GRU
        :param hidden_size: GRU hidden size (per direction)
        :param embedding: shared nn.Embedding module
        :param num_layers: number of stacked GRU layers
        :param dropout: inter-layer dropout rate (only effective when num_layers > 1)
        """
        super(Encoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.num_layers = num_layers
        self.dropout = dropout
        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # nn.GRU applies dropout only between layers; a non-zero value
            # with num_layers == 1 triggers a UserWarning, so suppress it.
            dropout=dropout if num_layers > 1 else 0.,
            bidirectional=True
        )
    def forward(self, src, src_len):
        """
        :param src: indexes, tensor:(seq_len, batch_size)
        :param src_len: length of each sequence, tensor:(batch_size,)
        :return: h at time t, tensor:(num_layers, batch_size, hidden_size*2)
        """
        embedded = self.embedding(src)
        # enforce_sorted=False lets callers pass batches in any length order;
        # PyTorch restores the original batch order in the returned states.
        pack = nn.utils.rnn.pack_padded_sequence(embedded, src_len, enforce_sorted=False)
        _, state_t = self.rnn(pack, None)
        # state_t: (num_layers*2, batch, hidden) with directions interleaved;
        # pair each layer's forward/backward states along the feature dim.
        return torch.cat((state_t[0:state_t.size(0):2], state_t[1:state_t.size(0):2]), 2)
class Decoder(nn.Module):
    """Decode target tokens from an initial state with a unidirectional GRU.

    Supports two modes per call: teacher forcing (feed the gold target
    sequence in one shot) and step-by-step generation (feed back the argmax
    of the previous step).
    """
    def __init__(self, input_size, hidden_size, embedding, num_layers=1, dropout=0.2):
        """
        :param input_size: embedding dimension fed to the GRU
        :param hidden_size: GRU hidden size
        :param embedding: shared nn.Embedding module (also defines vocab size)
        :param num_layers: number of stacked GRU layers
        :param dropout: inter-layer dropout rate (only effective when num_layers > 1)
        """
        super(Decoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.num_layers = num_layers
        self.dropout = dropout
        self.rnn = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            # nn.GRU applies dropout only between layers; a non-zero value
            # with num_layers == 1 triggers a UserWarning, so suppress it.
            dropout=dropout if num_layers > 1 else 0.,
            bidirectional=False
        )
        # Two-layer projection from hidden state to vocabulary logits.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, embedding.num_embeddings)
        )
    def forward(self, tgt, state, teacher_forcing):
        """
        :param tgt: indexes, tensor:(seq_len, batch_size)
        :param state: state, tensor:(num_layers, batch_size, hidden_size)
        :param teacher_forcing: rate, float; the branch is sampled once per call
        :return: (outputs, state),
                 tensor:(seq_len, batch_size, vocab_size), tensor:(num_layers, batch_size, hidden_size)
        """
        flag = np.random.random() < teacher_forcing
        # teacher_forcing mode, also for testing mode
        if flag:
            embedded = self.embedding(tgt)
            outputs, state = self.rnn(embedded, state)
            # Flatten (seq, batch) to apply the shared projection, then restore.
            outputs = outputs.view(-1, self.hidden_size)
            outputs = self.fc(outputs)
            outputs = outputs.view(embedded.size(0), embedded.size(1), -1)
        # generation mode: greedily feed back the previous argmax token
        else:
            outputs = []
            embedded = self.embedding(tgt[0:1])
            for i in range(tgt.size(0)):
                output, state = self.rnn(embedded, state)
                output = output.view(-1, self.hidden_size)
                output = self.fc(output)
                outputs.append(output)
                # Next input is the most likely token of this step.
                _, topi = torch.topk(output, k=1, dim=1)
                embedded = self.embedding(topi.transpose(0, 1))
            outputs = torch.stack(outputs)
        return outputs, state
class Seq2seq(nn.Module):
    """Join encoder and decoder into a full sequence-to-sequence model."""
    def __init__(self, input_size, hidden_size, embedding, num_layers=1, dropout=0.2):
        """
        :param input_size: embedding dimension
        :param hidden_size: hidden size of both encoder (per direction) and decoder
        :param embedding: shared nn.Embedding module
        :param num_layers: number of GRU layers in encoder and decoder
        :param dropout: inter-layer dropout rate
        """
        super(Seq2seq, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embedding = embedding
        self.num_layers = num_layers
        self.dropout = dropout
        # Projects the concatenated bidirectional encoder state
        # (hidden_size*2) down to the decoder's hidden_size.
        self.fc = nn.Sequential(
            nn.Linear(hidden_size*2, hidden_size),
            nn.Tanh()
        )
        self.encoder = Encoder(
            input_size=input_size,
            hidden_size=hidden_size,
            embedding=embedding,
            num_layers=num_layers,
            dropout=dropout
        )
        self.decoder = Decoder(
            input_size=input_size,
            hidden_size=hidden_size,
            embedding=embedding,
            num_layers=num_layers,
            dropout=dropout
        )
    def _init_state(self, src, src_len):
        """Encode src and project the bidirectional state to the decoder shape.

        :param src: indexes, tensor:(seq_len, batch_size)
        :param src_len: length of each sequence, tensor
        :return: decoder initial state, tensor:(num_layers, batch_size, hidden_size)
        """
        state = self.encoder(src, src_len)
        state = state.view(-1, self.hidden_size*2)
        state = self.fc(state)
        return state.view(self.num_layers, -1, self.hidden_size)
    def forward(self, src, src_len, tgt, teacher_forcing):
        """
        :param src: indexes, tensor:(seq_len, batch_size)
        :param src_len: length of each sequence, tensor
        :param tgt: indexes, tensor:(seq_len, batch_size)
        :param teacher_forcing: rate, float
        :return: (outputs, state),
                 tensor:(seq_len, batch_size, vocab_size), tensor:(num_layers, batch_size, hidden_size)
        """
        state = self._init_state(src, src_len)
        outputs, state = self.decoder(tgt, state, teacher_forcing)
        return outputs, state
    def gen(self, index, num_beams, max_len):
        """
        test mode: beam-search decode a single source sample
        :param index: a sample about src, tensor:(seq_len,)
        :param num_beams: beam width
        :param max_len: max length of result
        :return: result, list about index
        """
        src = index.unsqueeze(1)
        src_len = torch.LongTensor([src.size(0)])
        state = self._init_state(src, src_len)
        result = test_helper.beam_search(self.decoder, num_beams, max_len, state)
        # Strip the leading start token; also drop the trailing token when it
        # equals 2 (presumably the EOS id — TODO confirm against the vocab).
        if result[-1] == 2:
            return result[1: -1]
        else:
            return result[1:]