Title
'''
Description: A detailed walk-through of seq2seq code
Author: 365JHWZGo
Date: 2021-12-16 19:59:38
LastEditors: 365JHWZGo
LastEditTime: 2021-12-24 20:39:34
'''
Code
from torch import nn
import torch
import utils
from torch.utils.data import DataLoader
from torch.nn.functional import cross_entropy
class Seq2Seq(nn.Module):
def __init__(self, enc_v_dim, dec_v_dim, emb_dim, BATCH_SIZE, max_pred_len, start_token, end_token):
# enc_v_dim: encoder vocabulary size (number of distinct source tokens)
# dec_v_dim: decoder vocabulary size (number of distinct target tokens)
# emb_dim: word-embedding dimension
# BATCH_SIZE: batch size (also reused as the LSTM hidden size below)
# max_pred_len: maximum prediction length
# start_token: start-of-sequence token id
# end_token: end-of-sequence token id
super(Seq2Seq, self).__init__()
self.BATCH_SIZE = BATCH_SIZE
self.dec_v_dim = dec_v_dim
self.max_pred_len = max_pred_len
self.start_token = start_token
self.end_token = end_token
# encoder
# Embedding table for the encoder: enc_v_dim words, each represented by an emb_dim-dimensional vector
self.enc_embeddings = nn.Embedding(enc_v_dim, emb_dim)
# Initialize the embedding weights from N(0, 0.1)
self.enc_embeddings.weight.data.normal_(0, 0.1)
# Encoder LSTM: input size emb_dim, hidden size BATCH_SIZE, num_layers=1
# (the hidden size just happens to be set to the batch size, 32; the two are independent hyperparameters)
self.encoder = nn.LSTM(emb_dim, BATCH_SIZE, 1, batch_first=True)
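# A quick shape sketch with toy numbers (illustrative only, not part of this model):
#   emb = nn.Embedding(10, 4)                  # vocab of 10 words, 4-dim vectors
#   rnn = nn.LSTM(4, 32, 1, batch_first=True)  # input size 4, hidden size 32
#   ids = torch.LongTensor([[1, 2, 3]])        # [batch=1, seq_len=3]
#   out, (h, c) = rnn(emb(ids))                # out: [1, 3, 32], h and c: [1, 1, 32]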
# decoder
# Embedding table for the decoder: dec_v_dim words, each represented by an emb_dim-dimensional vector
self.dec_embeddings = nn.Embedding(dec_v_dim, emb_dim)
# Initialize the embedding weights from N(0, 0.1)
self.dec_embeddings.weight.data.normal_(0, 0.1)
# Decoder cell: an LSTMCell with input size emb_dim and hidden size BATCH_SIZE
self.decoder_cell = nn.LSTMCell(emb_dim, BATCH_SIZE)
# Output projection: maps the BATCH_SIZE-dimensional hidden state to dec_v_dim vocabulary logits
self.decoder_dense = nn.Linear(BATCH_SIZE, dec_v_dim)
# Create the optimizer
self.opt = torch.optim.Adam(self.parameters(), lr=0.001)
def encode(self, x):
# x has shape [batch, seq_len]
# embedded has shape [batch, seq_len, emb_dim]
embedded = self.enc_embeddings(x)
# Initial state: a tuple (h0, c0), each of shape [num_layers=1, batch, hidden_size]
hidden = (torch.zeros(1, x.shape[0], self.BATCH_SIZE), torch.zeros(
1, x.shape[0], self.BATCH_SIZE))
o, (h, c) = self.encoder(embedded, hidden)
return h, c
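# The (h, c) returned here is the fixed-length summary of the whole source sentence;
# the decoder starts from this state instead of from zeros, which is how the source
# information reaches the target side in this architecture.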
# model.train() and model.eval() switch the module between training and evaluation mode:
# (1) model.train()
#     Dropout and BatchNormalization behave as in training (dropout is active, BatchNorm updates its running statistics)
# (2) model.eval()
#     Dropout is disabled and BatchNormalization uses its stored running statistics
# (1) Never forget to call model.train() in the training code
# (2) Never forget to call model.eval() in the evaluation (or test) code,
#     which is why inference() below calls self.eval() at the start and self.train() at the end
def inference(self, x):
self.eval()
# hx, cx have shape [num_layers=1, batch=1, hidden_size]
hx, cx = self.encode(x)
# take layer 0: hx, cx now have shape [1, hidden_size]
hx, cx = hx[0], cx[0]
# start: a start-of-sequence token that kicks off decoding
start = torch.ones(x.shape[0], 1)
start[:, 0] = torch.tensor(self.start_token)
start = start.type(torch.LongTensor)
# start has shape (1, 1)
dec_emb_in = self.dec_embeddings(start)
# dec_emb_in now has shape [1, 1, emb_dim]
# permute() reorders the dimensions to [seq_len, batch, emb_dim],
# which for a single start token is still [1, 1, emb_dim]
dec_emb_in = dec_emb_in.permute(1, 0, 2)
dec_in = dec_emb_in[0]
# dec_in has shape [1, emb_dim]
output = []
for i in range(self.max_pred_len):
hx, cx = self.decoder_cell(dec_in, (hx, cx))
# hx has shape [1, hidden_size]
o = self.decoder_dense(hx)
# o has shape [1, dec_v_dim]
o = o.argmax(dim=1).view(-1, 1)
# o now has shape [1, 1], i.e. [batch, seq_len]: the index of the most likely next word
dec_in = self.dec_embeddings(o).permute(1, 0, 2)[0]
# dec_in has shape [1, emb_dim], i.e. [batch, emb_dim]: the predicted word is fed back as the next input
output.append(o)
output = torch.stack(output, dim=0)
# output has shape [max_pred_len, 1, 1]
self.train()
# after permute and view, the result has shape [batch, max_pred_len] = [1, 11]:
# for each sample, the predicted word index at every decoding step
return output.permute(1, 0, 2).view(-1, self.max_pred_len)
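# Decoding above is greedy: at each step the argmax word is embedded and fed back in as
# the next input, and the loop always runs for max_pred_len steps (it does not stop early
# when end_token is produced). Sketch of a single step, shapes for a batch of one:
#   hx, cx = self.decoder_cell(dec_in, (hx, cx))    # dec_in: [1, emb_dim]
#   next_id = self.decoder_dense(hx).argmax(dim=1)  # most probable word index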
def train_logit(self, x, y):
hx, cx = self.encode(x)
# take layer 0: hx, cx have shape [BATCH_SIZE, hidden_size]
hx, cx = hx[0], cx[0]
# dec_in is the decoder input: the target sequence without its last token,
# so the decoder makes (max_pred_len - 1) predictions, one for each next token
# dec_in has shape [BATCH_SIZE, max_pred_len - 1]
dec_in = y[:, :-1]
# dec_emb_in has shape [BATCH_SIZE, max_pred_len - 1, emb_dim]
dec_emb_in = self.dec_embeddings(dec_in)
# after permute: [max_pred_len - 1, BATCH_SIZE, emb_dim]
dec_emb_in = dec_emb_in.permute(1, 0, 2)
output = []
for i in range(dec_emb_in.shape[0]):
hx, cx = self.decoder_cell(dec_emb_in[i], (hx, cx))
o = self.decoder_dense(hx)
# o has shape [BATCH_SIZE, dec_v_dim]
output.append(o)
# output has shape [max_pred_len - 1, BATCH_SIZE, dec_v_dim]
output = torch.stack(output, dim=0)
# after permute: [BATCH_SIZE, max_pred_len - 1, dec_v_dim]
return output.permute(1, 0, 2)
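# This is teacher forcing: during training the decoder consumes the ground-truth tokens
# y[:, :-1] instead of its own previous predictions, so the logits at position i are
# trained to predict y[:, i + 1].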
def step(self, x, y):
# x has shape [BATCH_SIZE, src_len] (the source date, 8 tokens here)
# y has shape [BATCH_SIZE, max_pred_len]
# zero the optimizer's gradients
self.opt.zero_grad()
# logit has shape [BATCH_SIZE, max_pred_len - 1, dec_v_dim]
# logit holds the model's predictions for every position after the start token
logit = self.train_logit(x, y)
# since logit predicts the sequence after the start token, the label y also drops the start token
# dec_out has shape [BATCH_SIZE, max_pred_len - 1]
dec_out = y[:, 1:]
# compute the loss
loss = cross_entropy(
logit.reshape(-1, self.dec_v_dim), dec_out.reshape(-1))
loss.backward()
self.opt.step()
return loss.detach().numpy()
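# cross_entropy expects logits of shape [N, num_classes] and integer targets of shape [N],
# which is why both tensors are flattened; with the settings used below (batch 32,
# 10 target positions after dropping the start token):
#   logit:   [32, 10, dec_v_dim] -> reshape(-1, dec_v_dim) -> [320, dec_v_dim]
#   dec_out: [32, 10]            -> reshape(-1)            -> [320]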
# Build the dataset
dataset = utils.DateData(4000)
# Build the data loader
loader = DataLoader(
dataset=dataset,
batch_size=32,
shuffle=True
)
# Instantiate the Seq2Seq model
model = Seq2Seq(
enc_v_dim=dataset.num_word,
dec_v_dim=dataset.num_word,
emb_dim=16,
BATCH_SIZE=32,
max_pred_len=11,
start_token=dataset.start_token,
end_token=dataset.end_token
)
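# enc_v_dim and dec_v_dim are both dataset.num_word: the dataset exposes a single vocabulary
# size, so the encoder and decoder embeddings here are sized by the same word count.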
def train():
# print("Chinese time order: yy/mm/dd ",dataset.date_cn[:3],"\nEnglish time order: dd/M/yyyy", dataset.date_en[:3])
# print("Vocabularies: ", dataset.vocab)
# print(f"x index sample: \n{dataset.idx2str(dataset.x[0])}\n{dataset.x[0]}",
# f"\ny index sample: \n{dataset.idx2str(dataset.y[0])}\n{dataset.y[0]}")
for i in range(40):
for batch_idx, batch in enumerate(loader):
# each batch has three parts: the input data (Chinese dates), the labels (English dates), and the decoder lengths
bx, by, decoder_len = batch
# bx has shape [BATCH_SIZE, 8]
# by has shape [BATCH_SIZE, 11]
bx = bx.type(torch.LongTensor)
by = by.type(torch.LongTensor)
# compute the loss and take one optimization step
loss = model.step(bx, by)
if batch_idx % 70 == 0:
# the source input src
src = dataset.idx2str(bx[0].data.numpy())
# the target output
target = dataset.idx2str(by[0, 1:-1].data.numpy())
# the model's actual output res
# bx[0:1].shape == torch.Size([1, 8])
res = dataset.idx2str(model.inference(bx[0:1])[0].data.numpy())
print(
"Epoch: ", i,
"| t: ", batch_idx,
"| loss: %.3f" % loss,
"| input: ", src,
"| target: ", target,
"| inference: ", res,
)
if __name__ == "__main__":
train()
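As a quick sanity check on the shapes discussed in the comments, here is a small standalone snippet. The vocabulary size 12 and the sequence lengths are made-up toy numbers; it only assumes the Seq2Seq class and the imports defined above, builds a throw-away model, and prints the output shapes:

toy = Seq2Seq(enc_v_dim=12, dec_v_dim=12, emb_dim=16, BATCH_SIZE=32,
              max_pred_len=11, start_token=0, end_token=1)
x = torch.randint(0, 12, (32, 8))    # fake source batch: [BATCH_SIZE, src_len]
y = torch.randint(0, 12, (32, 11))   # fake target batch: [BATCH_SIZE, max_pred_len]
print(toy.train_logit(x, y).shape)   # torch.Size([32, 10, 12])
print(toy.inference(x[0:1]).shape)   # torch.Size([1, 11])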
Summary
Next, I will sum up why the code is written this way!
To be continued~