题目
'''
Description: seq2seq
Author: 365JHWZGo
Date: 2021-12-16 19:59:38
LastEditors: 365JHWZGo
LastEditTime: 2021-12-16 20:14:46
'''
Seq2Seq介绍
特点
- 输入N个数据,输出M个数据
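- 由编码器(Encoder)和解码器(Decoder)两部分组成;因为它把一个序列转换成另一个序列(sequence to sequence),所以被称为Seq2Seq
- 实际上可以看作N对1(编码器把N个输入压缩成一个向量)和1对M(解码器由这个向量生成M个输出)两个模型的组合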
使用场景
- 机器阅读
- 机器翻译
示意图
举一个具体的例子,假如要做翻译任务,将hello world翻译成法语
output:
每一个时间步最后一层的隐藏状态
hidden_state:
最后一个时间步所有层的隐藏状态(对LSTM来说还包括细胞状态c)
encoder可以用传统的rnn,或者使用性能更好的lstm、gru(也可以使用双向rnn)
decoder的初始隐藏状态取自encoder最后一个时刻的隐藏状态,第一个输入是表示翻译开始的标志(如<GO>)
之后每一步的输入为上一步翻译出的词,以及上一步的隐藏状态(训练时通常直接把真实的上一个词作为输入,即teacher forcing)
代码
from torch import nn
import torch
import numpy as np
import utils
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy,softmax
class Seq2Seq(nn.Module):
    """LSTM encoder/decoder (seq2seq) model that owns its own Adam optimizer.

    The encoder embeds the source token ids and runs a single-layer LSTM; its
    final (h, c) state initialises an ``LSTMCell`` decoder that emits one target
    token per step through a linear projection onto the target vocabulary.

    Args:
        enc_v_dim: source vocabulary size.
        dec_v_dim: target vocabulary size.
        emb_dim: embedding size used by both the encoder and decoder embeddings.
        units: LSTM hidden size.
        max_pred_len: number of tokens produced by ``inference``.
        start_token: token id fed to the decoder as its first input in ``inference``.
        end_token: stored on the instance; the decoding loop does not use it.
    """

    def __init__(self, enc_v_dim, dec_v_dim, emb_dim, units, max_pred_len, start_token, end_token):
        super().__init__()
        self.units = units
        self.dec_v_dim = dec_v_dim

        # encoder
        self.enc_embeddings = nn.Embedding(enc_v_dim, emb_dim)
        self.enc_embeddings.weight.data.normal_(0, 0.1)
        self.encoder = nn.LSTM(emb_dim, units, 1, batch_first=True)

        # decoder
        self.dec_embeddings = nn.Embedding(dec_v_dim, emb_dim)
        self.dec_embeddings.weight.data.normal_(0, 0.1)
        self.decoder_cell = nn.LSTMCell(emb_dim, units)
        self.decoder_dense = nn.Linear(units, dec_v_dim)

        # Created after every sub-module so the optimizer sees all parameters.
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)
        self.max_pred_len = max_pred_len
        self.start_token = start_token
        self.end_token = end_token

    def encode(self, x):
        """Encode token ids ``x`` ([n, step]) and return the final LSTM state.

        Returns:
            (h, c), each of shape [1, n, units] (one LSTM layer).
        """
        embedded = self.enc_embeddings(x)  # [n, step, emb]
        # Zero initial state allocated like the embeddings, so it follows the
        # model's device/dtype instead of always living on the CPU.
        hidden = (
            embedded.new_zeros(1, x.shape[0], self.units),
            embedded.new_zeros(1, x.shape[0], self.units),
        )
        _, (h, c) = self.encoder(embedded, hidden)
        return h, c

    def inference(self, x):
        """Greedily decode ``max_pred_len`` tokens for every source row of ``x``.

        Runs under ``torch.no_grad()`` (no autograd graph is built) and restores
        whatever train/eval mode the module was in before the call.

        Returns:
            LongTensor of shape [n, max_pred_len] with the predicted token ids.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                hx, cx = self.encode(x)
                hx, cx = hx[0], cx[0]  # drop the num_layers axis -> [n, units]
                # Every sequence starts decoding from the start token.
                tokens = torch.full(
                    (x.shape[0],), int(self.start_token), dtype=torch.long, device=x.device
                )
                output = []
                for _ in range(self.max_pred_len):
                    hx, cx = self.decoder_cell(self.dec_embeddings(tokens), (hx, cx))
                    # Greedy choice; the predicted token is the next step's input.
                    tokens = self.decoder_dense(hx).argmax(dim=1)  # [n]
                    output.append(tokens)
        finally:
            self.train(was_training)
        return torch.stack(output, dim=1)

    def train_logit(self, x, y):
        """Compute teacher-forced decoder logits.

        The decoder is fed ``y[:, :-1]`` (the target without its last token), so
        the logits at position t are trained to predict ``y[:, t + 1]``.

        Returns:
            Tensor of shape [n, y_len - 1, dec_v_dim].
        """
        hx, cx = self.encode(x)
        hx, cx = hx[0], cx[0]  # [n, units]
        dec_emb_in = self.dec_embeddings(y[:, :-1])  # [n, y_len - 1, emb]
        output = []
        for t in range(dec_emb_in.shape[1]):
            hx, cx = self.decoder_cell(dec_emb_in[:, t], (hx, cx))
            output.append(self.decoder_dense(hx))
        return torch.stack(output, dim=1)

    def step(self, x, y):
        """Run one optimisation step on the batch (x, y).

        Returns:
            The cross-entropy loss as a 0-d NumPy array.
        """
        self.opt.zero_grad()
        logit = self.train_logit(x, y)
        dec_out = y[:, 1:]  # targets are y shifted left by one token
        loss = cross_entropy(logit.reshape(-1, self.dec_v_dim), dec_out.reshape(-1))
        loss.backward()
        self.opt.step()
        # .cpu() so this also works when the model lives on a GPU.
        return loss.detach().cpu().numpy()
def train():
    """Train a Seq2Seq model on ``utils.DateData(4000)`` for 40 epochs.

    Prints a few dataset samples first, then every 70th batch prints the loss
    together with the source, the target (without its first and last token)
    and the model's greedy decode of the batch's first example.
    """
    data = utils.DateData(4000)
    print("Chinese time order: yy/mm/dd ", data.date_cn[:3], "\nEnglish time order: dd/M/yyyy", data.date_en[:3])
    print("Vocabularies: ", data.vocab)
    first_x, first_y = data.x[0], data.y[0]
    print(
        f"x index sample: \n{data.idx2str(first_x)}\n{first_x}",
        f"\ny index sample: \n{data.idx2str(first_y)}\n{first_y}",
    )

    loader = DataLoader(data, batch_size=32, shuffle=True)
    model = Seq2Seq(
        data.num_word,
        data.num_word,
        emb_dim=16,
        units=32,
        max_pred_len=11,
        start_token=data.start_token,
        end_token=data.end_token,
    )

    for epoch in range(40):
        for t, (src_ids, tgt_ids, _decoder_len) in enumerate(loader):
            src_ids = src_ids.long()
            tgt_ids = tgt_ids.long()
            loss = model.step(src_ids, tgt_ids)
            if t % 70 != 0:
                continue
            # Strip the first and last target tokens for display.
            target = data.idx2str(tgt_ids[0, 1:-1].detach().numpy())
            decoded = model.inference(src_ids[0:1])
            res = data.idx2str(decoded[0].detach().numpy())
            src = data.idx2str(src_ids[0].detach().numpy())
            print(
                "Epoch: ", epoch,
                "| t: ", t,
                "| loss: %.3f" % loss,
                "| input: ", src,
                "| target: ", target,
                "| inference: ", res,
            )


if __name__ == "__main__":
    train()
运行结果
Chinese time order: yy/mm/dd ['31-04-26', '04-07-18', '33-06-06']
English time order: dd/M/yyyy ['26/Apr/2031', '18/Jul/2004', '06/Jun/2033']
Vocabularies: {'-', '<PAD>', '5', 'Nov', 'Aug', 'Feb', '9', 'Oct', '4', '3', 'Sep', 'Apr', '<GO>', 'Jun', 'Jan', '/', 'Jul', '7', '1', '8', '<EOS>', '0', 'Mar', 'Dec', 'May', '2', '6'}
x index sample:
31-04-26
[6 4 1 3 7 1 5 9]
y index sample:
<GO>26/Apr/2031<EOS>
[14 5 9 2 15 2 5 3 6 4 13]
Epoch: 0 | t: 0 | loss: 3.311 | input: 16-01-26 | target: 26/Jan/2016 | inference: 00000000000
Epoch: 0 | t: 70 | loss: 2.533 | input: 86-02-24 | target: 24/Feb/1986 | inference: 2//////////
Chinese time order: yy/mm/dd ['31-04-26', '04-07-18', '33-06-06']
English time order: dd/M/yyyy ['26/Apr/2031', '18/Jul/2004', '06/Jun/2033']
Vocabularies: {'Sep', '/', '<EOS>', '-', 'Aug', '3', '1', '<GO>', 'Feb', 'Jun', '7', 'Oct', 'May', '8', 'Dec', 'Nov', '6', 'Jul', '5', '0', '9', '<PAD>', 'Apr', '2', 'Mar', 'Jan', '4'}
x index sample:
31-04-26
[6 4 1 3 7 1 5 9]
y index sample:
<GO>26/Apr/2031<EOS>
[14 5 9 2 15 2 5 3 6 4 13]
Epoch: 0 | t: 0 | loss: 3.293 | input: 85-06-18 | target: 18/Jun/1985 | inference: 4<GO><GO><GO><GO><GO><GO><GO><GO><GO><GO>
Epoch: 0 | t: 70 | loss: 2.502 | input: 90-10-26 | target: 26/Oct/1990 | inference: 2//////////
Epoch: 1 | t: 0 | loss: 2.141 | input: 04-01-08 | target: 08/Jan/2004 | inference: 1/////00<EOS>
Epoch: 1 | t: 70 | loss: 1.791 | input: 98-04-26 | target: 26/Apr/1998 | inference: 1////200<EOS>
Epoch: 2 | t: 0 | loss: 1.536 | input: 23-04-09 | target: 09/Apr/2023 | inference: 02//20000<EOS>
Epoch: 2 | t: 70 | loss: 1.295 | input: 24-11-24 | target: 24/Nov/2024 | inference: 11//2000<EOS>
Epoch: 3 | t: 0 | loss: 1.218 | input: 28-05-28 | target: 28/May/2028 | inference: 12//2000<EOS>
Epoch: 3 | t: 70 | loss: 1.153 | input: 21-03-07 | target: 07/Mar/2021 | inference: 22/Mar/2002<EOS>
Epoch: 4 | t: 0 | loss: 1.123 | input: 27-08-25 | target: 25/Aug/2027 | inference: 12/Jan/2017<EOS>
Epoch: 4 | t: 70 | loss: 1.055 | input: 08-06-19 | target: 19/Jun/2008 | inference: 01/Mar/2018<EOS>
Epoch: 5 | t: 0 | loss: 1.035 | input: 89-04-20 | target: 20/Apr/1989 | inference: 12/Mar/2018<EOS>
Epoch: 5 | t: 70 | loss: 0.998 | input: 18-11-12 | target: 12/Nov/2018 | inference: 22/Jan/2003<EOS>
Epoch: 6 | t: 0 | loss: 0.971 | input: 74-12-25 | target: 25/Dec/1974 | inference: 26/Mar/198<EOS>
Epoch: 6 | t: 70 | loss: 0.932 | input: 81-08-19 | target: 19/Aug/1981 | inference: 19/Jan/2015<EOS>
Epoch: 7 | t: 0 | loss: 0.903 | input: 06-03-09 | target: 09/Mar/2006 | inference: 19/Jul/2017<EOS>
Epoch: 7 | t: 70 | loss: 0.860 | input: 89-04-08 | target: 08/Apr/1989 | inference: 09/Mar/1998<EOS>
Epoch: 8 | t: 0 | loss: 0.842 | input: 25-12-23 | target: 23/Dec/2025 | inference: 22/Jan/2024<EOS>
Epoch: 8 | t: 70 | loss: 0.780 | input: 98-01-12 | target: 12/Jan/1998 | inference: 21/Jul/1997<EOS>
Epoch: 9 | t: 0 | loss: 0.763 | input: 32-06-24 | target: 24/Jun/2032 | inference: 26/Mar/2024<EOS>
Epoch: 9 | t: 70 | loss: 0.718 | input: 76-07-27 | target: 27/Jul/1976 | inference: 27/May/198<EOS>
Epoch: 10 | t: 0 | loss: 0.700 | input: 89-12-19 | target: 19/Dec/1989 | inference: 19/Jan/1997<EOS>
Epoch: 10 | t: 70 | loss: 0.670 | input: 78-06-15 | target: 15/Jun/1978 | inference: 16/May/1994<EOS>
Epoch: 11 | t: 0 | loss: 0.659 | input: 13-05-16 | target: 16/May/2013 | inference: 16/Mar/2017<EOS>
Epoch: 11 | t: 70 | loss: 0.619 | input: 88-05-31 | target: 31/May/1988 | inference: 12/Jul/1997<EOS>
Epoch: 12 | t: 0 | loss: 0.585 | input: 08-09-09 | target: 09/Sep/2008 | inference: 09/Jan/2007<EOS>
Epoch: 12 | t: 70 | loss: 0.572 | input: 93-09-30 | target: 30/Sep/1993 | inference: 20/Jan/1997<EOS>
Epoch: 13 | t: 0 | loss: 0.557 | input: 75-10-15 | target: 15/Oct/1975 | inference: 14/May/1976<EOS>
Epoch: 13 | t: 70 | loss: 0.539 | input: 25-02-03 | target: 03/Feb/2025 | inference: 03/Mar/2026<EOS>
Epoch: 14 | t: 0 | loss: 0.526 | input: 99-09-22 | target: 22/Sep/1999 | inference: 22/Mar/1999<EOS>
Epoch: 14 | t: 70 | loss: 0.500 | input: 08-09-11 | target: 11/Sep/2008 | inference: 11/Jan/2007<EOS>
Epoch: 15 | t: 0 | loss: 0.495 | input: 32-02-05 | target: 05/Feb/2032 | inference: 04/Mar/2027<EOS>
Epoch: 15 | t: 70 | loss: 0.489 | input: 11-04-09 | target: 09/Apr/2011 | inference: 09/Sep/2019<EOS>
Epoch: 16 | t: 0 | loss: 0.475 | input: 03-07-19 | target: 19/Jul/2003 | inference: 19/Jan/2001<EOS>
Epoch: 16 | t: 70 | loss: 0.458 | input: 19-07-12 | target: 12/Jul/2019 | inference: 12/Jan/2019<EOS>
Epoch: 17 | t: 0 | loss: 0.448 | input: 23-07-26 | target: 26/Jul/2023 | inference: 26/Mar/2023<EOS>
Epoch: 17 | t: 70 | loss: 0.433 | input: 74-12-29 | target: 29/Dec/1974 | inference: 29/Jan/1976<EOS>
Epoch: 18 | t: 0 | loss: 0.421 | input: 27-09-16 | target: 16/Sep/2027 | inference: 16/Jan/2027<EOS>
Epoch: 18 | t: 70 | loss: 0.389 | input: 29-05-26 | target: 26/May/2029 | inference: 26/Mar/2023<EOS>
Epoch: 19 | t: 0 | loss: 0.388 | input: 02-06-11 | target: 11/Jun/2002 | inference: 11/Sep/2002<EOS>
Epoch: 19 | t: 70 | loss: 0.367 | input: 26-06-22 | target: 22/Jun/2026 | inference: 22/Mar/2026<EOS>
Epoch: 20 | t: 0 | loss: 0.362 | input: 86-06-08 | target: 08/Jun/1986 | inference: 08/Aug/1984<EOS>
Epoch: 20 | t: 70 | loss: 0.347 | input: 26-05-08 | target: 08/May/2026 | inference: 08/Sep/2026<EOS>
Epoch: 21 | t: 0 | loss: 0.338 | input: 99-08-09 | target: 09/Aug/1999 | inference: 09/Aug/1999<EOS>
Epoch: 21 | t: 70 | loss: 0.319 | input: 13-12-11 | target: 11/Dec/2013 | inference: 11/Nov/2013<EOS>
Epoch: 22 | t: 0 | loss: 0.315 | input: 91-12-11 | target: 11/Dec/1991 | inference: 11/Nov/1991<EOS>
Epoch: 22 | t: 70 | loss: 0.295 | input: 10-12-31 | target: 31/Dec/2010 | inference: 31/Nov/2011<EOS>
Epoch: 23 | t: 0 | loss: 0.269 | input: 30-11-29 | target: 29/Nov/2030 | inference: 29/Nov/2030<EOS>
Epoch: 23 | t: 70 | loss: 0.268 | input: 17-02-08 | target: 08/Feb/2017 | inference: 08/Feb/2017<EOS>
Epoch: 24 | t: 0 | loss: 0.259 | input: 78-08-21 | target: 21/Aug/1978 | inference: 21/Aug/1978<EOS>
Epoch: 24 | t: 70 | loss: 0.234 | input: 32-12-01 | target: 01/Dec/2032 | inference: 01/Dec/2032<EOS>
Epoch: 25 | t: 0 | loss: 0.245 | input: 83-04-30 | target: 30/Apr/1983 | inference: 30/Mar/1983<EOS>
Epoch: 25 | t: 70 | loss: 0.226 | input: 31-10-09 | target: 09/Oct/2031 | inference: 09/Nov/2031<EOS>
Epoch: 26 | t: 0 | loss: 0.222 | input: 15-04-09 | target: 09/Apr/2015 | inference: 09/May/2015<EOS>
Epoch: 26 | t: 70 | loss: 0.200 | input: 13-12-22 | target: 22/Dec/2013 | inference: 22/Dec/2013<EOS>
Epoch: 27 | t: 0 | loss: 0.200 | input: 21-03-09 | target: 09/Mar/2021 | inference: 09/Feb/2021<EOS>
Epoch: 27 | t: 70 | loss: 0.182 | input: 14-05-26 | target: 26/May/2014 | inference: 26/May/2014<EOS>
Epoch: 28 | t: 0 | loss: 0.185 | input: 20-12-27 | target: 27/Dec/2020 | inference: 27/Dec/2020<EOS>
Epoch: 28 | t: 70 | loss: 0.173 | input: 02-05-23 | target: 23/May/2002 | inference: 23/May/2002<EOS>
Epoch: 29 | t: 0 | loss: 0.169 | input: 92-12-28 | target: 28/Dec/1992 | inference: 28/Dec/1992<EOS>
Epoch: 29 | t: 70 | loss: 0.156 | input: 25-08-28 | target: 28/Aug/2025 | inference: 28/Sep/2025<EOS>
Epoch: 30 | t: 0 | loss: 0.154 | input: 95-02-21 | target: 21/Feb/1995 | inference: 21/Feb/1995<EOS>
Epoch: 30 | t: 70 | loss: 0.168 | input: 84-11-11 | target: 11/Nov/1984 | inference: 11/Nov/1984<EOS>
Epoch: 31 | t: 0 | loss: 0.143 | input: 28-11-14 | target: 14/Nov/2028 | inference: 14/Nov/2028<EOS>
Epoch: 31 | t: 70 | loss: 0.137 | input: 18-02-02 | target: 02/Feb/2018 | inference: 02/Feb/2018<EOS>
Epoch: 32 | t: 0 | loss: 0.121 | input: 91-05-11 | target: 11/May/1991 | inference: 11/May/1991<EOS>
Epoch: 32 | t: 70 | loss: 0.132 | input: 14-03-19 | target: 19/Mar/2014 | inference: 19/Feb/2014<EOS>
Epoch: 33 | t: 0 | loss: 0.123 | input: 24-12-04 | target: 04/Dec/2024 | inference: 04/Dec/2024<EOS>
Epoch: 33 | t: 70 | loss: 0.115 | input: 24-07-22 | target: 22/Jul/2024 | inference: 22/Mar/2024<EOS>
Epoch: 34 | t: 0 | loss: 0.124 | input: 27-09-04 | target: 04/Sep/2027 | inference: 04/Sep/2027<EOS>
Epoch: 34 | t: 70 | loss: 0.113 | input: 80-08-29 | target: 29/Aug/1980 | inference: 29/Aug/1980<EOS>
Epoch: 35 | t: 0 | loss: 0.097 | input: 13-02-12 | target: 12/Feb/2013 | inference: 12/Feb/2013<EOS>
Epoch: 35 | t: 70 | loss: 0.097 | input: 80-12-13 | target: 13/Dec/1980 | inference: 13/Dec/1980<EOS>
Epoch: 36 | t: 0 | loss: 0.098 | input: 27-02-04 | target: 04/Feb/2027 | inference: 04/Feb/2027<EOS>
Epoch: 36 | t: 70 | loss: 0.098 | input: 12-12-29 | target: 29/Dec/2012 | inference: 29/Dec/2012<EOS>
Epoch: 37 | t: 0 | loss: 0.097 | input: 94-12-01 | target: 01/Dec/1994 | inference: 01/Dec/1994<EOS>
Epoch: 37 | t: 70 | loss: 0.081 | input: 30-11-29 | target: 29/Nov/2030 | inference: 29/Nov/2030<EOS>
Epoch: 38 | t: 0 | loss: 0.077 | input: 05-03-03 | target: 03/Mar/2005 | inference: 03/Mar/2005<EOS>
Epoch: 38 | t: 70 | loss: 0.094 | input: 78-07-31 | target: 31/Jul/1978 | inference: 31/Mar/1978<EOS>
Epoch: 39 | t: 0 | loss: 0.074 | input: 01-12-04 | target: 04/Dec/2001 | inference: 04/Dec/2001<EOS>
Epoch: 39 | t: 70 | loss: 0.072 | input: 88-10-03 | target: 03/Oct/1988 | inference: 03/Oct/1988<EOS>
总结
通过上面的理解我们知道:Encoder负责理解输入的上文,Decoder负责思考如何在理解这句话的基础上完成任务
Decoder
可以看作一个解压器,但它并不是把压缩好的信息原样还原,而是解压成另外一种形式,即换一种表达方式