The core of seq2seq machine translation is an Encoder (LSTM) and a Decoder (LSTM), wrapped together in a Model(Encoder, Decoder) network, plus two lookup tables: index2word and word2index.
Encoder model diagram
The input here is the English data. The LSTM's per-step output is not what matters; what we need is the hidden state, which is then passed to the Decoder.
This post uses character-level embeddings as the input. Whether to work at the word or character level depends on the task; characters are used here because the character set is small.
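To make "what we need is the hidden state" concrete, here is a minimal standalone sketch (the sizes are illustrative, not the ones used in the full code below):

import torch
import torch.nn as nn

emb = nn.Embedding(10, 50)                  # toy vocabulary of 10 characters
lstm = nn.LSTM(50, 100, batch_first=True)
x = torch.randint(0, 10, (2, 7))            # a batch of 2 sentences, 7 characters each
output, hidden = lstm(emb(x))
print(output.shape)                         # torch.Size([2, 7, 100]) - per-step outputs, unused here
print(hidden[0].shape, hidden[1].shape)     # (h, c): each torch.Size([1, 2, 100])

Note that for an LSTM, "hidden" is actually a tuple (h, c), and it is this whole tuple that gets handed to the Decoder.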
Decoder model diagram
The hidden state fed into the Decoder is the Encoder's hidden state. The first input is the <BOS> token, from which the model predicts the output 你; 你 is then fed back in as the new input to predict 好, and so on.
(<BOS> is used to predict 你, and 你 is then used to predict the next character, 好.)
In the diagram, red marks the Decoder's inputs and green marks the Decoder's predicted outputs.
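The loop described above can be shown with a runnable toy (random, untrained weights; the indices 0-3 for <BOS>/你/好/<EOS> are made up for the demo):

import torch
import torch.nn as nn

emb = nn.Embedding(4, 8)                  # toy vocab: 0=<BOS>, 1=你, 2=好, 3=<EOS>
lstm = nn.LSTM(8, 16, batch_first=True)
classifier = nn.Linear(16, 4)

decoder_input = torch.tensor([[0]])       # start from <BOS>
hidden = None                             # stands in for the Encoder's hidden state
result = []
for _ in range(5):                        # cap the loop for the demo
    out, hidden = lstm(emb(decoder_input), hidden)
    w = int(torch.argmax(classifier(out), dim=-1))
    if w == 3:                            # stop at <EOS>
        break
    result.append(w)
    decoder_input = torch.tensor([[w]])   # feed the prediction back in
print(result)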
Why is the Decoder used differently during training and inference?
These two modes have names: decoding from the ground-truth answer at each step is called "teacher forcing", while feeding the previous step's own output in as the next step's input is called "free running".
So can free running really not be used during training? Of course it can; there is no theoretical obstacle. In practice, though, people found that training this way is very hard. With no guidance, the model starts out making essentially random predictions, and as the saying goes, "one wrong step leads to the next" - each error pushes the following step further off course, so errors compound and the accumulated training loss becomes huge, making training painfully slow. If instead a "teacher" can step in at every position and supply the correct previous character as the input, the decoder gets on track quickly and training converges much faster. That is why the technique is called teacher forcing: its whole purpose is to make training easier. (The price is that at inference time the model only ever sees its own predictions, a train/test mismatch known in the literature as exposure bias.)
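The difference between the two modes comes down to a single line: what is fed in at step t. Here is a sketch, with a dummy decode_step standing in for one decoder-plus-classifier step (the function is hypothetical, purely for illustration):

def decode_step(prev_token, hidden):
    # stand-in for embedding -> LSTM -> classifier -> argmax
    return (prev_token + 1) % 10, hidden

gold = [4, 7, 2, 9]                       # ground-truth target tokens
hidden = None

# teacher forcing: step t is conditioned on the *gold* token from step t-1
pred_tf, prev = [], 0                     # 0 plays the role of <BOS>
for t in range(len(gold)):
    out, hidden = decode_step(prev, hidden)
    pred_tf.append(out)
    prev = gold[t]                        # ignore our own prediction

# free running: step t is conditioned on our *own* prediction from step t-1
pred_fr, prev = [], 0
for t in range(len(gold)):
    out, hidden = decode_step(prev, hidden)
    pred_fr.append(out)
    prev = out                            # feed back the prediction

Note how in the forward method of Seq2Seq below, simply slicing ch_index into decoder_input and label implements teacher forcing without any explicit loop.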
Full code
import torch
import pickle
import pandas as pd
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# Load the parallel corpus; nums limits how many sentence pairs are used
def get_datas(file="datas/translate.csv", nums=None):
    all_datas = pd.read_csv(file)
    en_datas = list(all_datas["english"])
    ch_datas = list(all_datas["chinese"])
    if nums is None:
        return en_datas, ch_datas
    else:
        return en_datas[:nums], ch_datas[:nums]
class MyDataset(Dataset):
    def __init__(self, en_data, ch_data, ch_word_2_index, en_word_2_index):
        super().__init__()
        self.en_data = en_data
        self.ch_data = ch_data
        self.ch_word_2_index = ch_word_2_index
        self.en_word_2_index = en_word_2_index

    def __getitem__(self, index):
        en = self.en_data[index]
        ch = self.ch_data[index]
        # map every character to its vocabulary index
        en_index = [self.en_word_2_index[i] for i in en]
        ch_index = [self.ch_word_2_index[i] for i in ch]
        return en_index, ch_index
    # collate_fn: pad a batch to equal length and wrap targets with <BOS>/<EOS>
    def batch_data_process(self, batch_datas):
        global device
        en_index, ch_index = [], []
        en_len, ch_len = [], []
        for en, ch in batch_datas:
            en_index.append(en)
            ch_index.append(ch)
            en_len.append(len(en))
            ch_len.append(len(ch))
        max_en_len = max(en_len)
        max_ch_len = max(ch_len)
        # pad English sentences to the longest one in the batch
        en_index = [i + [self.en_word_2_index["<PAD>"]] * (max_en_len - len(i)) for i in en_index]
        # Chinese targets additionally get <BOS> at the front and <EOS> at the end
        ch_index = [[self.ch_word_2_index["<BOS>"]] + i + [self.ch_word_2_index["<EOS>"]] + [self.ch_word_2_index["<PAD>"]] * (max_ch_len - len(i)) for i in ch_index]
        en_index = torch.tensor(en_index, device=device)
        ch_index = torch.tensor(ch_index, device=device)
        return en_index, ch_index

    def __len__(self):
        assert len(self.ch_data) == len(self.en_data)
        return len(self.ch_data)
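# A quick illustration of what batch_data_process produces (the index values
# are hypothetical; real ones depend on the vocab files):
#   batch           = [([3, 7], [12, 5]), ([3], [12])]        # (en, ch) pairs
#   en_index tensor = [[3, 7], [3, <PAD>]]                    # padded to max English length
#   ch_index tensor = [[<BOS>, 12, 5, <EOS>],
#                      [<BOS>, 12, <EOS>, <PAD>]]             # wrapped, then padded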
class Encoder(nn.Module):
    # en_corpus_len is the size of the English vocabulary
    def __init__(self, encoder_embedding_num, encoder_hidden_num, en_corpus_len):
        super().__init__()
        self.embedding = nn.Embedding(en_corpus_len, encoder_embedding_num)
        self.lstm = nn.LSTM(encoder_embedding_num, encoder_hidden_num, batch_first=True)

    def forward(self, en_index):
        en_embedding = self.embedding(en_index)
        # encoder_output is not needed; only the final hidden state (h, c) is
        encoder_output, encoder_hidden = self.lstm(en_embedding)
        return encoder_hidden
class Decoder(nn.Module):
    def __init__(self, decoder_embedding_num, decoder_hidden_num, ch_corpus_len):
        super().__init__()
        self.embedding = nn.Embedding(ch_corpus_len, decoder_embedding_num)
        self.lstm = nn.LSTM(decoder_embedding_num, decoder_hidden_num, batch_first=True)

    def forward(self, decoder_input, hidden):
        # hidden is initialised from the Encoder's final state
        embedding = self.embedding(decoder_input)
        decoder_output, decoder_hidden = self.lstm(embedding, hidden)
        return decoder_output, decoder_hidden
# greedy, character-by-character decoding of a single English sentence
def translate(sentence):
    global en_word_2_index, model, device, ch_word_2_index, ch_index_2_word
    en_index = torch.tensor([[en_word_2_index[i] for i in sentence]], device=device)
    result = []
    with torch.no_grad():
        encoder_hidden = model.encoder(en_index)
        decoder_input = torch.tensor([[ch_word_2_index["<BOS>"]]], device=device)
        decoder_hidden = encoder_hidden
        while True:
            decoder_output, decoder_hidden = model.decoder(decoder_input, decoder_hidden)
            pre = model.classifier(decoder_output)
            w_index = int(torch.argmax(pre, dim=-1))
            word = ch_index_2_word[w_index]
            # stop on <EOS> or when the translation gets suspiciously long
            if word == "<EOS>" or len(result) > 50:
                break
            result.append(word)
            # feed the predicted character back in as the next input
            decoder_input = torch.tensor([[w_index]], device=device)
    print("Translation:", "".join(result))
class Seq2Seq(nn.Module):
    def __init__(self, encoder_embedding_num, encoder_hidden_num, en_corpus_len, decoder_embedding_num, decoder_hidden_num, ch_corpus_len):
        super().__init__()
        self.encoder = Encoder(encoder_embedding_num, encoder_hidden_num, en_corpus_len)
        self.decoder = Decoder(decoder_embedding_num, decoder_hidden_num, ch_corpus_len)
        self.classifier = nn.Linear(decoder_hidden_num, ch_corpus_len)  # projects hidden states to vocabulary logits
        self.cross_loss = nn.CrossEntropyLoss()

    def forward(self, en_index, ch_index):
        # teacher forcing: inputs are the gold sequence minus the last token,
        # labels are the gold sequence minus the first one
        decoder_input = ch_index[:, :-1]
        label = ch_index[:, 1:]
        encoder_hidden = self.encoder(en_index)
        decoder_output, _decoder_hidden = self.decoder(decoder_input, encoder_hidden)
        pre = self.classifier(decoder_output)
        loss = self.cross_loss(pre.reshape(-1, pre.shape[-1]), label.reshape(-1))
        return loss
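# How the shift in forward implements teacher forcing, on a toy target
# 你好 (token layout is illustrative):
#   ch_index      = [<BOS>, 你, 好, <EOS>, <PAD>]
#   decoder_input = [<BOS>, 你, 好, <EOS>]        # ch_index[:, :-1]
#   label         = [你,  好, <EOS>, <PAD>]       # ch_index[:, 1:]
# At every position the decoder sees the gold previous character and is
# trained to predict the gold next one.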
if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # load the prebuilt vocabularies (the first pickled element is unused here)
    with open("datas/ch.vec", "rb") as f1:
        _, ch_word_2_index, ch_index_2_word = pickle.load(f1)
    with open("datas/en.vec", "rb") as f2:
        _, en_word_2_index, en_index_2_word = pickle.load(f2)

    ch_corpus_len = len(ch_word_2_index)
    en_corpus_len = len(en_word_2_index)

    # append the special tokens to both vocabularies
    ch_word_2_index.update({"<PAD>": ch_corpus_len, "<BOS>": ch_corpus_len + 1, "<EOS>": ch_corpus_len + 2})
    en_word_2_index.update({"<PAD>": en_corpus_len})
    ch_index_2_word += ["<PAD>", "<BOS>", "<EOS>"]
    en_index_2_word += ["<PAD>"]
    ch_corpus_len += 3
    en_corpus_len = len(en_word_2_index)

    en_datas, ch_datas = get_datas(nums=200)

    # hyperparameters
    encoder_embedding_num = 50
    encoder_hidden_num = 100
    decoder_embedding_num = 107
    decoder_hidden_num = 100
    batch_size = 2
    epoch = 30
    lr = 0.001

    dataset = MyDataset(en_datas, ch_datas, ch_word_2_index, en_word_2_index)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=dataset.batch_data_process)

    model = Seq2Seq(encoder_embedding_num, encoder_hidden_num, en_corpus_len, decoder_embedding_num, decoder_hidden_num, ch_corpus_len)
    model = model.to(device)

    # optimizer
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for e in range(epoch):
        for en_index, ch_index in dataloader:
            loss = model(en_index, ch_index)
            loss.backward()
            opt.step()
            opt.zero_grad()  # zero the gradients
        print(f"loss:{loss:.3f}")

    while True:
        s = input("Enter an English sentence: ")
        translate(s)
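The script expects datas/translate.csv with english and chinese columns, plus two pickled vocabulary files. The exact layout of the .vec files is not shown in this post; from the unpickling code they must be 3-tuples whose first element is discarded. A hypothetical preparation snippet matching that assumption:

import pickle

ch_words = ["你", "好"]                    # toy character vocabulary
ch_word_2_index = {w: i for i, w in enumerate(ch_words)}
ch_index_2_word = list(ch_words)
with open("datas/ch.vec", "wb") as f:
    # the first element is unused by the training script (it unpacks it into "_")
    pickle.dump((None, ch_word_2_index, ch_index_2_word), f)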
References
https://zhuanlan.zhihu.com/p/147310766
https://www.bilibili.com/video/BV1hf4y1u7ez?p=8