seq2seq翻译任务代码详细分析

本文详细介绍了使用PyTorch构建Seq2Seq模型的过程,包括编码器、解码器的搭建,以及训练和推断的实现。模型采用LSTM作为核心组件,并利用Adam优化器进行参数更新。在训练过程中,通过交叉熵损失函数计算误差,并进行反向传播。代码清晰展示了如何处理输入序列和预测输出序列。
摘要由CSDN通过智能技术生成

文章目录

题目

'''
Description: seq2seq代码详细分析
Autor: 365JHWZGo
Date: 2021-12-16 19:59:38
LastEditors: 365JHWZGo
LastEditTime: 2021-12-24 20:39:34
'''

代码

from torch import nn
import torch
import utils
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy


class Seq2Seq(nn.Module):
    def __init__(self, enc_v_dim, dec_v_dim, emb_dim, BATCH_SIZE, max_pred_len, start_token, end_token):
        # enc_v_dim:encoder_vector_dim 编码器的输入维度
        # dec_v_dim:decoder_vector_dim 解码器的输入维度
        # emb_dim:embedding_dim        词嵌入的维度
        # BATCH_SIZE                   BATCH_SIZE
        # max_pred_len                 最大预测长度
        # start_token                  开始标志
        # end_token                    结束标志

        super(Seq2Seq, self).__init__()
        self.BATCH_SIZE = BATCH_SIZE
        self.dec_v_dim = dec_v_dim
        self.max_pred_len = max_pred_len
        self.start_token = start_token
        self.end_token = end_token

        # encoder
        # 创建一个词嵌入模型,有enc_v_dim个单词,每个单词用emb_dim维表示
        self.enc_embeddings = nn.Embedding(enc_v_dim, emb_dim)
        # 初始化词嵌入模型的weight
        self.enc_embeddings.weight.data.normal_(0, 0.1)
        # 创建LSTM,word_dim=emv_dim,hidden_size=BATCH_SIZE,num_layer=1
        self.encoder = nn.LSTM(emb_dim, BATCH_SIZE, 1, batch_first=True)

        # decoder
        # 创建一个词嵌入模型,有dec_v_dim个单词,每个单词用emb_dim维表示
        self.dec_embeddings = nn.Embedding(dec_v_dim, emb_dim)
        # 初始化词嵌入模型的weight
        self.dec_embeddings.weight.data.normal_(0, 0.1)
        # 创建decoder_cell,LSTMCell输入维度word_dim=emb_dim, hidden_size=BATCH_SIZE
        self.decoder_cell = nn.LSTMCell(emb_dim, BATCH_SIZE)
        # 创建decoder_dense,将BATCH_SIZE的输入维度转化为dec_v_dim
        self.decoder_dense = nn.Linear(BATCH_SIZE, dec_v_dim)

        # 创建优化器
        self.opt = torch.optim.Adam(self.parameters(), lr=0.001)

    def encode(self, x):
        # x的维度BATCH_SIZE*enc_v_dim
        # embedded的维度BATCH_SIZE*enc_v_dim*emb_dim
        embedded = self.enc_embeddings(x)
        # hidden维度2*num_layer*BATCH_SIZE*hidden_size
        hidden = (torch.zeros(1, x.shape[0], self.BATCH_SIZE), torch.zeros(
            1, x.shape[0], self.BATCH_SIZE))
        o, (h, c) = self.encoder(embedded, hidden)
        return h, c

    # model.train()和model.eval()分别在训练和测试中都要写,它们的作用如下:
    # (1). model.train()
    # 启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为True
    # (2). model.eval()
    # 不启用 BatchNormalization 和 Dropout,将BatchNormalization和Dropout置为False
    # (1). 在训练模块中千万不要忘了写model.train()
    # (2). 在评估(或测试)模块千万不要忘了写model.eval()
    def inference(self, x):
        self.eval()
        # hx,cx的维度1*1*hidden_size
        hx, cx = self.encode(x)
        # hx,cx的维度1*hidden_size
        hx, cx = hx[0], cx[0]
        # start给句子的创建一个开始字符
        start = torch.ones(x.shape[0], 1)
        start[:, 0] = torch.tensor(self.start_token)
        start = start.type(torch.LongTensor)
        # start维度(1,1)
        dec_emb_in = self.dec_embeddings(start)
        # 上一步dec_emb_in维度1*1*emb_dim
        # permute()将tensor的维度换位
        # 此时dec_emb_in维度1*1*emb_dim
        dec_emb_in = dec_emb_in.permute(1, 0, 2)
        dec_in = dec_emb_in[0]
        # dec_in的维度1*emb_dim
        output = []
        for i in range(self.max_pred_len):
            hx, cx = self.decoder_cell(dec_in, (hx, cx))
            # hx的维度1*hidden_size
            o = self.decoder_dense(hx)
            # o的维度1*dec_v_dim
            o = o.argmax(dim=1).view(-1, 1)
            # o的维度1*1=(BATCH_SIZE*seq_len)
            dec_in = self.dec_embeddings(o).permute(1, 0, 2)[0]
            # dec_in的维度1*emb_dim(BATCH_SIZE*emb_dim)
            output.append(o)
        output = torch.stack(output, dim=0)
        # output的维度max_pred_len*1*1
        self.train()
        # output的维度BATCH_SIZE*max_pred_len=(1*11)
        # 每一个批次预测出来的max_pre_len个单词预测出来的最大下标
        return output.permute(1, 0, 2).view(-1, self.max_pred_len)

    def train_logit(self, x, y):
        hx, cx = self.encode(x)
        # hx、cx维度BATCH_SIZE*hidden_size
        hx, cx = hx[0], cx[0]
        # dec_in作为decoder的输入,其作用是预测下一个可能的结果,所以预测次数为(max_pred_len-1)
        # dec_in 维度BATCH_SIZE*(max_pred_len-1)
        dec_in = y[:, :-1]
        # dec_emb_in 维度BATCH_SIZE*(max_pred_len-1)*emb_dim
        dec_emb_in = self.dec_embeddings(dec_in)
        # dec_emb_in 维度(max_pred_len-1)*BATCH_SIZE*emb_dim
        dec_emb_in = dec_emb_in.permute(1, 0, 2)
        output = []
        for i in range(dec_emb_in.shape[0]):
            hx, cx = self.decoder_cell(dec_emb_in[i], (hx, cx))
            o = self.decoder_dense(hx)
            # o的维度为BATCH_SIZE*dec_v_dim
            output.append(o)
        # output的维度(max_pred_len*-1)*BATCH_SIZE*dec_v_dim
        output = torch.stack(output, dim=0)
        # output的维度BATCH_SIZE*(max_pred_len*-1)*dec_v_dim
        return output.permute(1, 0, 2)

    def step(self, x, y):
        # x的维度BATCH_SIZE*max_pre_len
        # y的维度BATCH_SIZE*max_pre_len
        
        # 优化器梯度清零
        self.opt.zero_grad()
        
        # logit的维度BATCH_SIZE*(max_pred_len*-1)*dec_v_dim
        # logit是通过训练得到的结果,max_pred_len*-1是除去开始字符后的字符串
        logit = self.train_logit(x, y)
        
        # 因为logit输出的是没有开始字符的结果,所以标签y也需要去除开始字符串
        # dec_out维度BATCH_SIZE*(max_pred_len*-1)
        dec_out = y[:, 1:]

        # 计算误差
        loss = cross_entropy(
            logit.reshape(-1, self.dec_v_dim), dec_out.reshape(-1))
        loss.backward()
        self.opt.step()
        return loss.detach().numpy()


# 创建数据集
dataset = utils.DateData(4000)

# 创建数据加载器
loader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=True
)

# 创建seq2seq实例
model = Seq2Seq(
    enc_v_dim=dataset.num_word,
    dec_v_dim=dataset.num_word,
    emb_dim=16,
    BATCH_SIZE=32,
    max_pred_len=11,
    start_token=dataset.start_token,
    end_token=dataset.end_token
)

def train():

    # print("Chinese time order: yy/mm/dd ",dataset.date_cn[:3],"\nEnglish time order: dd/M/yyyy", dataset.date_en[:3])
    # print("Vocabularies: ", dataset.vocab)
    # print(f"x index sample:  \n{dataset.idx2str(dataset.x[0])}\n{dataset.x[0]}",
    # f"\ny index sample:  \n{dataset.idx2str(dataset.y[0])}\n{dataset.y[0]}")

    for i in range(40):
        for batch_idx, batch in enumerate(loader):
            # batch 包含三部分输入数据【汉语数据】,标签【英文数据】,解码器长度
            bx, by, decoder_len = batch
            # bx的维度BATCH_SIZE*8
            # by的维度BATCH_SIZE*11
            bx = bx.type(torch.LongTensor)
            by = by.type(torch.LongTensor)
            # 计算误差
            loss = model.step(bx, by)
            if batch_idx % 70 == 0:
                # 输入 src
                src = dataset.idx2str(bx[0].data.numpy())
                # 目标输出 target
                target = dataset.idx2str(by[0, 1:-1].data.numpy())
                # 实际输出 res
                # bx[0:1].shape=torch.Size([1, 8])
                res = dataset.idx2str(model.inference(bx[0:1])[0].data.numpy())
                
                print(
                    "Epoch: ", i,
                    "| t: ", batch_idx,
                    "| loss: %.3f" % loss,
                    "| input: ", src,
                    "| target: ", target,
                    "| inference: ", res,
                )


if __name__ == "__main__":
    train()

总结

接下来,再总结为什么要这样写代码!
未完待续~

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

365JHWZGo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值