Using the seq2seq framework for a translation task

Topic

'''
Description: seq2seq
Author: 365JHWZGo
Date: 2021-12-16 19:59:38
LastEditors: 365JHWZGo
LastEditTime: 2021-12-16 20:14:46
'''

Introduction to Seq2Seq

Features
  • Takes N inputs and produces M outputs
  • Composed of an encoder and a decoder, which is why it is called Seq2Seq
  • In effect it chains an N-to-1 model (the encoder) with a 1-to-M model (the decoder); see the sketch after this list
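
A minimal sketch of that chaining, assuming GRU layers and made-up dimensions (illustrative only; the post's actual model appears in the Code section below):

import torch
from torch import nn

enc = nn.GRU(input_size=8, hidden_size=16, batch_first=True)   # N inputs -> 1 context vector
dec = nn.GRU(input_size=8, hidden_size=16, batch_first=True)   # 1 context vector -> M outputs

src = torch.randn(2, 5, 8)        # batch of 2, N = 5 source steps
_, context = enc(src)             # context: [1, 2, 16], the "1" in N-to-1
tgt_in = torch.randn(2, 7, 8)     # M = 7 target steps (e.g. shifted target embeddings)
out, _ = dec(tgt_in, context)     # out: [2, 7, 16], the "M" in 1-to-M
print(out.shape)                  # torch.Size([2, 7, 16])
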
Use cases
  • Machine reading comprehension
  • Machine translation
Diagram

As a concrete example, suppose the task is translation: translating "hello world" into French.

(Figure: seq2seq encoder-decoder diagram for translating "hello world")

output: the last layer's hidden state at every time step
hidden_state: the states of all layers at the final time step
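
To make the two return values concrete, here is a small self-contained check (a 2-layer LSTM with made-up dimensions; not part of the post's code):

import torch
from torch import nn

rnn = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 8)       # [batch, step, feature]
output, (h, c) = rnn(x)
print(output.shape)             # [4, 10, 16] -> last layer's state at every time step
print(h.shape)                  # [2, 4, 16]  -> every layer's state at the final time step
print(c.shape)                  # [2, 4, 16]  -> the cell states follow the same layout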

The encoder can be a plain RNN, or a better-performing LSTM or GRU (a bidirectional RNN is also possible).
The decoder's first input is the encoder's final hidden state together with the start-of-translation token;
each subsequent step takes the previously generated token and the previous decoding hidden state as input.
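
A compressed sketch of that feedback loop, with made-up dimensions (illustrative only; the inference method in the code below does the same thing batch-wise):

import torch
from torch import nn

vocab, emb_dim, units, max_len, start_id = 12, 8, 16, 5, 0   # toy sizes, chosen arbitrarily
embed = nn.Embedding(vocab, emb_dim)
cell = nn.LSTMCell(emb_dim, units)
dense = nn.Linear(units, vocab)

h = torch.zeros(1, units)            # stand-in for the encoder's final hidden state
c = torch.zeros(1, units)            # stand-in for the encoder's final cell state
token = torch.tensor([start_id])     # decoding starts from the <GO> token
for _ in range(max_len):
    h, c = cell(embed(token), (h, c))    # previous prediction + previous hidden state
    token = dense(h).argmax(dim=1)       # greedy pick becomes the next step's input
    print(token.item())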

Code

from torch import nn
import torch
import numpy as np
import utils
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy,softmax

class Seq2Seq(nn.Module):
    def __init__(self,enc_v_dim, dec_v_dim, emb_dim, units, max_pred_len, start_token, end_token):
        super().__init__()
        self.units = units
        self.dec_v_dim = dec_v_dim

        # encoder
        self.enc_embeddings = nn.Embedding(enc_v_dim,emb_dim)
        self.enc_embeddings.weight.data.normal_(0,0.1)
        self.encoder = nn.LSTM(emb_dim,units,1,batch_first=True)
    

        # decoder
        self.dec_embeddings = nn.Embedding(dec_v_dim,emb_dim)
        self.dec_embeddings.weight.data.normal_(0,0.1)
        self.decoder_cell = nn.LSTMCell(emb_dim,units)
        self.decoder_dense = nn.Linear(units,dec_v_dim)

        self.opt = torch.optim.Adam(self.parameters(),lr=0.001)
        self.max_pred_len = max_pred_len
        self.start_token = start_token
        self.end_token = end_token

    
    def encode(self,x):
        embedded = self.enc_embeddings(x)   # [n, step, emb]
        # zero-initialized (h0, c0), each [num_layers=1, n, units]
        hidden = (torch.zeros(1,x.shape[0],self.units),torch.zeros(1,x.shape[0],self.units))
        o,(h,c) = self.encoder(embedded,hidden)
        return h,c
    
    def inference(self,x):
        self.eval()
        hx,cx = self.encode(x)
        hx,cx = hx[0],cx[0]     # [n, units]: LSTMCell expects 2-D states
        start = torch.ones(x.shape[0],1)
        start[:,0] = torch.tensor(self.start_token)     # every sequence starts from <GO>
        start = start.type(torch.LongTensor)
        dec_emb_in = self.dec_embeddings(start)
        dec_emb_in = dec_emb_in.permute(1,0,2)
        dec_in = dec_emb_in[0]      # [n, emb]: embedding of the <GO> token
        output = []
        for i in range(self.max_pred_len):
            hx, cx = self.decoder_cell(dec_in, (hx, cx))
            o = self.decoder_dense(hx)          # [n, dec_v_dim] logits
            o = o.argmax(dim=1).view(-1,1)      # greedy pick of the next token
            dec_in = self.dec_embeddings(o).permute(1,0,2)[0]   # feed the prediction back in
            output.append(o)
        output = torch.stack(output,dim=0)
        self.train()

        return output.permute(1,0,2).view(-1,self.max_pred_len)

    
    def train_logit(self,x,y):
        hx,cx = self.encode(x)
        hx,cx = hx[0],cx[0]     # [n, units]
        dec_in = y[:,:-1]       # teacher forcing: ground-truth tokens without <EOS>
        dec_emb_in = self.dec_embeddings(dec_in)
        dec_emb_in = dec_emb_in.permute(1,0,2)      # [step, n, emb]
        output = []
        for i in range(dec_emb_in.shape[0]):
            hx, cx = self.decoder_cell(dec_emb_in[i], (hx, cx))
            o = self.decoder_dense(hx)      # [n, dec_v_dim] logits for this step
            output.append(o)
        output = torch.stack(output,dim=0)
        return output.permute(1,0,2)        # [n, step, dec_v_dim]
    
    def step(self,x,y):
        self.opt.zero_grad()
        batch_size = x.shape[0]
        logit = self.train_logit(x,y)    
        dec_out = y[:,1:]       # targets: decoder input shifted by one (drop <GO>)
        loss = cross_entropy(logit.reshape(-1,self.dec_v_dim),dec_out.reshape(-1))
        loss.backward()
        self.opt.step()
        return loss.detach().numpy()

def train():
    dataset = utils.DateData(4000)
    print("Chinese time order: yy/mm/dd ",dataset.date_cn[:3],"\nEnglish time order: dd/M/yyyy", dataset.date_en[:3])
    print("Vocabularies: ", dataset.vocab)
    print(f"x index sample:  \n{dataset.idx2str(dataset.x[0])}\n{dataset.x[0]}",
    f"\ny index sample:  \n{dataset.idx2str(dataset.y[0])}\n{dataset.y[0]}")
    loader = DataLoader(dataset,batch_size=32,shuffle=True)
    model = Seq2Seq(dataset.num_word,dataset.num_word,emb_dim=16,units=32,max_pred_len=11,start_token=dataset.start_token,end_token=dataset.end_token)
    for i in range(40):
        for batch_idx , batch in enumerate(loader):
            bx, by, decoder_len = batch
            bx = bx.type(torch.LongTensor)
            by = by.type(torch.LongTensor)
            loss = model.step(bx,by)
            if batch_idx % 70 == 0:
                target = dataset.idx2str(by[0, 1:-1].data.numpy())
                pred = model.inference(bx[0:1])
                res = dataset.idx2str(pred[0].data.numpy())
                src = dataset.idx2str(bx[0].data.numpy())
                print(
                    "Epoch: ",i,
                    "| t: ", batch_idx,
                    "| loss: %.3f" % loss,
                    "| input: ", src,
                    "| target: ", target,
                    "| inference: ", res,
                )


if __name__ == "__main__":
    train()
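
The utils module itself is not included in the post. Below is a hypothetical, minimal stand-in for utils.DateData, reconstructed only from the attributes the training loop above touches (date_cn, date_en, vocab, x, y, idx2str, num_word, start_token, end_token, and items of the form (x, y, length)); the date range, random seed, and vocabulary ordering are assumptions, so the printed indices may differ from the results shown next.

import datetime
import numpy as np
from torch.utils.data import Dataset

class DateData(Dataset):
    # Hypothetical stand-in for the tutorial's utils.DateData (details assumed).
    PAD_ID = 0

    def __init__(self, n):
        np.random.seed(1)
        self.date_cn, self.date_en = [], []
        for _ in range(n):
            # random day in a ~60-year window (the exact range is an assumption)
            day = datetime.datetime(1974, 1, 1) + datetime.timedelta(
                days=int(np.random.randint(0, 365 * 60)))
            self.date_cn.append(day.strftime("%y-%m-%d"))    # e.g. 31-04-26
            self.date_en.append(day.strftime("%d/%b/%Y"))    # e.g. 26/Apr/2031
        self.vocab = set([str(i) for i in range(10)] + ["-", "/", "<GO>", "<EOS>"]
                         + [d.split("/")[1] for d in self.date_en])
        self.v2i = {v: i for i, v in enumerate(sorted(self.vocab), start=1)}
        self.v2i["<PAD>"] = DateData.PAD_ID
        self.vocab.add("<PAD>")
        self.i2v = {i: v for v, i in self.v2i.items()}
        self.x, self.y = [], []
        for cn, en in zip(self.date_cn, self.date_en):
            self.x.append([self.v2i[c] for c in cn])
            # "26/Apr/2031" -> ['2','6','/','Apr','/','2','0','3','1'], wrapped in <GO>/<EOS>
            tokens = list(en[:3]) + [en[3:6]] + list(en[6:])
            self.y.append([self.v2i["<GO>"]] + [self.v2i[t] for t in tokens]
                          + [self.v2i["<EOS>"]])
        self.x, self.y = np.array(self.x), np.array(self.y)
        self.start_token = self.v2i["<GO>"]
        self.end_token = self.v2i["<EOS>"]

    def __len__(self):
        return len(self.x)

    @property
    def num_word(self):
        return len(self.vocab)

    def __getitem__(self, index):
        return self.x[index], self.y[index], len(self.y[index])

    def idx2str(self, idx):
        return "".join([self.i2v[int(i)] for i in idx if i != DateData.PAD_ID])

With a stand-in like this saved as utils.py, the train() function above should run end to end.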

Results

Chinese time order: yy/mm/dd  ['31-04-26', '04-07-18', '33-06-06'] 
English time order: dd/M/yyyy ['26/Apr/2031', '18/Jul/2004', '06/Jun/2033']
Vocabularies:  {'-', '<PAD>', '5', 'Nov', 'Aug', 'Feb', '9', 'Oct', '4', '3', 'Sep', 'Apr', '<GO>', 'Jun', 'Jan', '/', 'Jul', '7', '1', '8', '<EOS>', '0', 'Mar', 'Dec', 'May', '2', '6'}
x index sample:  
31-04-26
[6 4 1 3 7 1 5 9] 
y index sample:  
<GO>26/Apr/2031<EOS>
[14  5  9  2 15  2  5  3  6  4 13]
Epoch:  0 | t:  0 | loss: 3.311 | input:  16-01-26 | target:  26/Jan/2016 | inference:  00000000000
Epoch:  0 | t:  70 | loss: 2.533 | input:  86-02-24 | target:  24/Feb/1986 | inference:  2//////////
Chinese time order: yy/mm/dd  ['31-04-26', '04-07-18', '33-06-06'] 
English time order: dd/M/yyyy ['26/Apr/2031', '18/Jul/2004', '06/Jun/2033']
Vocabularies:  {'Sep', '/', '<EOS>', '-', 'Aug', '3', '1', '<GO>', 'Feb', 'Jun', '7', 'Oct', 'May', '8', 'Dec', 'Nov', '6', 'Jul', '5', '0', '9', '<PAD>', 'Apr', '2', 'Mar', 'Jan', '4'}
x index sample:  
31-04-26
[6 4 1 3 7 1 5 9] 
y index sample:  
<GO>26/Apr/2031<EOS>
[14  5  9  2 15  2  5  3  6  4 13]
Epoch:  0 | t:  0 | loss: 3.293 | input:  85-06-18 | target:  18/Jun/1985 | inference:  4<GO><GO><GO><GO><GO><GO><GO><GO><GO><GO>
Epoch:  0 | t:  70 | loss: 2.502 | input:  90-10-26 | target:  26/Oct/1990 | inference:  2//////////
Epoch:  1 | t:  0 | loss: 2.141 | input:  04-01-08 | target:  08/Jan/2004 | inference:  1/////00<EOS>
Epoch:  1 | t:  70 | loss: 1.791 | input:  98-04-26 | target:  26/Apr/1998 | inference:  1////200<EOS>
Epoch:  2 | t:  0 | loss: 1.536 | input:  23-04-09 | target:  09/Apr/2023 | inference:  02//20000<EOS>
Epoch:  2 | t:  70 | loss: 1.295 | input:  24-11-24 | target:  24/Nov/2024 | inference:  11//2000<EOS>
Epoch:  3 | t:  0 | loss: 1.218 | input:  28-05-28 | target:  28/May/2028 | inference:  12//2000<EOS>
Epoch:  3 | t:  70 | loss: 1.153 | input:  21-03-07 | target:  07/Mar/2021 | inference:  22/Mar/2002<EOS>
Epoch:  4 | t:  0 | loss: 1.123 | input:  27-08-25 | target:  25/Aug/2027 | inference:  12/Jan/2017<EOS>
Epoch:  4 | t:  70 | loss: 1.055 | input:  08-06-19 | target:  19/Jun/2008 | inference:  01/Mar/2018<EOS>
Epoch:  5 | t:  0 | loss: 1.035 | input:  89-04-20 | target:  20/Apr/1989 | inference:  12/Mar/2018<EOS>
Epoch:  5 | t:  70 | loss: 0.998 | input:  18-11-12 | target:  12/Nov/2018 | inference:  22/Jan/2003<EOS>
Epoch:  6 | t:  0 | loss: 0.971 | input:  74-12-25 | target:  25/Dec/1974 | inference:  26/Mar/198<EOS>
Epoch:  6 | t:  70 | loss: 0.932 | input:  81-08-19 | target:  19/Aug/1981 | inference:  19/Jan/2015<EOS>
Epoch:  7 | t:  0 | loss: 0.903 | input:  06-03-09 | target:  09/Mar/2006 | inference:  19/Jul/2017<EOS>
Epoch:  7 | t:  70 | loss: 0.860 | input:  89-04-08 | target:  08/Apr/1989 | inference:  09/Mar/1998<EOS>
Epoch:  8 | t:  0 | loss: 0.842 | input:  25-12-23 | target:  23/Dec/2025 | inference:  22/Jan/2024<EOS>
Epoch:  8 | t:  70 | loss: 0.780 | input:  98-01-12 | target:  12/Jan/1998 | inference:  21/Jul/1997<EOS>
Epoch:  9 | t:  0 | loss: 0.763 | input:  32-06-24 | target:  24/Jun/2032 | inference:  26/Mar/2024<EOS>
Epoch:  9 | t:  70 | loss: 0.718 | input:  76-07-27 | target:  27/Jul/1976 | inference:  27/May/198<EOS>
Epoch:  10 | t:  0 | loss: 0.700 | input:  89-12-19 | target:  19/Dec/1989 | inference:  19/Jan/1997<EOS>
Epoch:  10 | t:  70 | loss: 0.670 | input:  78-06-15 | target:  15/Jun/1978 | inference:  16/May/1994<EOS>
Epoch:  11 | t:  0 | loss: 0.659 | input:  13-05-16 | target:  16/May/2013 | inference:  16/Mar/2017<EOS>
Epoch:  11 | t:  70 | loss: 0.619 | input:  88-05-31 | target:  31/May/1988 | inference:  12/Jul/1997<EOS>
Epoch:  12 | t:  0 | loss: 0.585 | input:  08-09-09 | target:  09/Sep/2008 | inference:  09/Jan/2007<EOS>
Epoch:  12 | t:  70 | loss: 0.572 | input:  93-09-30 | target:  30/Sep/1993 | inference:  20/Jan/1997<EOS>
Epoch:  13 | t:  0 | loss: 0.557 | input:  75-10-15 | target:  15/Oct/1975 | inference:  14/May/1976<EOS>
Epoch:  13 | t:  70 | loss: 0.539 | input:  25-02-03 | target:  03/Feb/2025 | inference:  03/Mar/2026<EOS>
Epoch:  14 | t:  0 | loss: 0.526 | input:  99-09-22 | target:  22/Sep/1999 | inference:  22/Mar/1999<EOS>
Epoch:  14 | t:  70 | loss: 0.500 | input:  08-09-11 | target:  11/Sep/2008 | inference:  11/Jan/2007<EOS>
Epoch:  15 | t:  0 | loss: 0.495 | input:  32-02-05 | target:  05/Feb/2032 | inference:  04/Mar/2027<EOS>
Epoch:  15 | t:  70 | loss: 0.489 | input:  11-04-09 | target:  09/Apr/2011 | inference:  09/Sep/2019<EOS>
Epoch:  16 | t:  0 | loss: 0.475 | input:  03-07-19 | target:  19/Jul/2003 | inference:  19/Jan/2001<EOS>
Epoch:  16 | t:  70 | loss: 0.458 | input:  19-07-12 | target:  12/Jul/2019 | inference:  12/Jan/2019<EOS>
Epoch:  17 | t:  0 | loss: 0.448 | input:  23-07-26 | target:  26/Jul/2023 | inference:  26/Mar/2023<EOS>
Epoch:  17 | t:  70 | loss: 0.433 | input:  74-12-29 | target:  29/Dec/1974 | inference:  29/Jan/1976<EOS>
Epoch:  18 | t:  0 | loss: 0.421 | input:  27-09-16 | target:  16/Sep/2027 | inference:  16/Jan/2027<EOS>
Epoch:  18 | t:  70 | loss: 0.389 | input:  29-05-26 | target:  26/May/2029 | inference:  26/Mar/2023<EOS>
Epoch:  19 | t:  0 | loss: 0.388 | input:  02-06-11 | target:  11/Jun/2002 | inference:  11/Sep/2002<EOS>
Epoch:  19 | t:  70 | loss: 0.367 | input:  26-06-22 | target:  22/Jun/2026 | inference:  22/Mar/2026<EOS>
Epoch:  20 | t:  0 | loss: 0.362 | input:  86-06-08 | target:  08/Jun/1986 | inference:  08/Aug/1984<EOS>
Epoch:  20 | t:  70 | loss: 0.347 | input:  26-05-08 | target:  08/May/2026 | inference:  08/Sep/2026<EOS>
Epoch:  21 | t:  0 | loss: 0.338 | input:  99-08-09 | target:  09/Aug/1999 | inference:  09/Aug/1999<EOS>
Epoch:  21 | t:  70 | loss: 0.319 | input:  13-12-11 | target:  11/Dec/2013 | inference:  11/Nov/2013<EOS>
Epoch:  22 | t:  0 | loss: 0.315 | input:  91-12-11 | target:  11/Dec/1991 | inference:  11/Nov/1991<EOS>
Epoch:  22 | t:  70 | loss: 0.295 | input:  10-12-31 | target:  31/Dec/2010 | inference:  31/Nov/2011<EOS>
Epoch:  23 | t:  0 | loss: 0.269 | input:  30-11-29 | target:  29/Nov/2030 | inference:  29/Nov/2030<EOS>
Epoch:  23 | t:  70 | loss: 0.268 | input:  17-02-08 | target:  08/Feb/2017 | inference:  08/Feb/2017<EOS>
Epoch:  24 | t:  0 | loss: 0.259 | input:  78-08-21 | target:  21/Aug/1978 | inference:  21/Aug/1978<EOS>
Epoch:  24 | t:  70 | loss: 0.234 | input:  32-12-01 | target:  01/Dec/2032 | inference:  01/Dec/2032<EOS>
Epoch:  25 | t:  0 | loss: 0.245 | input:  83-04-30 | target:  30/Apr/1983 | inference:  30/Mar/1983<EOS>
Epoch:  25 | t:  70 | loss: 0.226 | input:  31-10-09 | target:  09/Oct/2031 | inference:  09/Nov/2031<EOS>
Epoch:  26 | t:  0 | loss: 0.222 | input:  15-04-09 | target:  09/Apr/2015 | inference:  09/May/2015<EOS>
Epoch:  26 | t:  70 | loss: 0.200 | input:  13-12-22 | target:  22/Dec/2013 | inference:  22/Dec/2013<EOS>
Epoch:  27 | t:  0 | loss: 0.200 | input:  21-03-09 | target:  09/Mar/2021 | inference:  09/Feb/2021<EOS>
Epoch:  27 | t:  70 | loss: 0.182 | input:  14-05-26 | target:  26/May/2014 | inference:  26/May/2014<EOS>
Epoch:  28 | t:  0 | loss: 0.185 | input:  20-12-27 | target:  27/Dec/2020 | inference:  27/Dec/2020<EOS>
Epoch:  28 | t:  70 | loss: 0.173 | input:  02-05-23 | target:  23/May/2002 | inference:  23/May/2002<EOS>
Epoch:  29 | t:  0 | loss: 0.169 | input:  92-12-28 | target:  28/Dec/1992 | inference:  28/Dec/1992<EOS>
Epoch:  29 | t:  70 | loss: 0.156 | input:  25-08-28 | target:  28/Aug/2025 | inference:  28/Sep/2025<EOS>
Epoch:  30 | t:  0 | loss: 0.154 | input:  95-02-21 | target:  21/Feb/1995 | inference:  21/Feb/1995<EOS>
Epoch:  30 | t:  70 | loss: 0.168 | input:  84-11-11 | target:  11/Nov/1984 | inference:  11/Nov/1984<EOS>
Epoch:  31 | t:  0 | loss: 0.143 | input:  28-11-14 | target:  14/Nov/2028 | inference:  14/Nov/2028<EOS>
Epoch:  31 | t:  70 | loss: 0.137 | input:  18-02-02 | target:  02/Feb/2018 | inference:  02/Feb/2018<EOS>
Epoch:  32 | t:  0 | loss: 0.121 | input:  91-05-11 | target:  11/May/1991 | inference:  11/May/1991<EOS>
Epoch:  32 | t:  70 | loss: 0.132 | input:  14-03-19 | target:  19/Mar/2014 | inference:  19/Feb/2014<EOS>
Epoch:  33 | t:  0 | loss: 0.123 | input:  24-12-04 | target:  04/Dec/2024 | inference:  04/Dec/2024<EOS>
Epoch:  33 | t:  70 | loss: 0.115 | input:  24-07-22 | target:  22/Jul/2024 | inference:  22/Mar/2024<EOS>
Epoch:  34 | t:  0 | loss: 0.124 | input:  27-09-04 | target:  04/Sep/2027 | inference:  04/Sep/2027<EOS>
Epoch:  34 | t:  70 | loss: 0.113 | input:  80-08-29 | target:  29/Aug/1980 | inference:  29/Aug/1980<EOS>
Epoch:  35 | t:  0 | loss: 0.097 | input:  13-02-12 | target:  12/Feb/2013 | inference:  12/Feb/2013<EOS>
Epoch:  35 | t:  70 | loss: 0.097 | input:  80-12-13 | target:  13/Dec/1980 | inference:  13/Dec/1980<EOS>
Epoch:  36 | t:  0 | loss: 0.098 | input:  27-02-04 | target:  04/Feb/2027 | inference:  04/Feb/2027<EOS>
Epoch:  36 | t:  70 | loss: 0.098 | input:  12-12-29 | target:  29/Dec/2012 | inference:  29/Dec/2012<EOS>
Epoch:  37 | t:  0 | loss: 0.097 | input:  94-12-01 | target:  01/Dec/1994 | inference:  01/Dec/1994<EOS>
Epoch:  37 | t:  70 | loss: 0.081 | input:  30-11-29 | target:  29/Nov/2030 | inference:  29/Nov/2030<EOS>
Epoch:  38 | t:  0 | loss: 0.077 | input:  05-03-03 | target:  03/Mar/2005 | inference:  03/Mar/2005<EOS>
Epoch:  38 | t:  70 | loss: 0.094 | input:  78-07-31 | target:  31/Jul/1978 | inference:  31/Mar/1978<EOS>
Epoch:  39 | t:  0 | loss: 0.074 | input:  01-12-04 | target:  04/Dec/2001 | inference:  04/Dec/2001<EOS>
Epoch:  39 | t:  70 | loss: 0.072 | input:  88-10-03 | target:  03/Oct/1988 | inference:  03/Oct/1988<EOS>

Summary

From the above we can see that the Encoder is responsible for understanding the source text, while the Decoder works out how to perform the task on top of that understanding.

The Decoder acts like a decompressor, but rather than restoring the compressed information to its original form, it unpacks it into a different form, i.e., a different way of expressing the same content.
