【Reference: Seq2Seq machine translation, all code written by hand from scratch_哔哩哔哩_bilibili】
【Reference: shouxieai/seq2seq_translation: seq2seq_translation】
Data
translate.csv
english,chinese
Hi.,嗨。
Hi.,你好。
Run.,你用跑的。
Wait!,等等!
Wait!,等一下!
Begin.,开始!
Begin,开始
Hello!,你好。
I try.,我试试。
I won!,我赢了。
Oh no!,不会吧。
Cheers!,乾杯!
Got it?,你懂了吗?
He ran.,他跑了。
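The script in the next section unpickles two vocabulary files, datas\\ch.vec and datas\\en.vec, each holding a (vectors, word_2_index, index_2_word) triple; only the two lookup tables are used. These files are not shown in the original post. As a rough idea of their layout, here is a minimal, hypothetical sketch that builds character-level vocabularies from translate.csv (the unused first slot is stored as None):

# Hypothetical vocab builder, not part of the original post: writes pickles
# in the (vec, word_2_index, index_2_word) layout the main script expects.
import pickle
import pandas as pd

def build_vocab(sentences):
    # Character-level vocabulary, matching the per-character indexing in MyDataset
    index_2_word = sorted({ch for s in sentences for ch in s})
    word_2_index = {w: i for i, w in enumerate(index_2_word)}
    return word_2_index, index_2_word

datas = pd.read_csv("datas\\translate.csv")
for column, path in [("english", "datas\\en.vec"), ("chinese", "datas\\ch.vec")]:
    word_2_index, index_2_word = build_vocab(list(datas[column]))
    with open(path, "wb") as f:
        pickle.dump((None, word_2_index, index_2_word), f)  # slot 0 is unused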
Code
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import pickle


# Load the parallel data from the CSV file
def get_datas(file="datas\\translate.csv", nums=None):
    all_datas = pd.read_csv(file)
    en_datas = list(all_datas["english"])
    ch_datas = list(all_datas["chinese"])
    if nums is None:
        return en_datas, ch_datas
    else:
        return en_datas[:nums], ch_datas[:nums]
class MyDataset(Dataset):
    def __init__(self, en_data, ch_data, en_word_2_index, ch_word_2_index):
        self.en_data = en_data
        self.ch_data = ch_data
        self.en_word_2_index = en_word_2_index
        self.ch_word_2_index = ch_word_2_index

    def __getitem__(self, index):
        en = self.en_data[index]
        ch = self.ch_data[index]
        # Character-level tokenization: one index per character
        en_index = [self.en_word_2_index[i] for i in en]
        ch_index = [self.ch_word_2_index[i] for i in ch]
        return en_index, ch_index

    def batch_data_process(self, batch_datas):
        """
        Pad all samples within a batch to the same length.
        :param batch_datas: list of (en_index, ch_index) pairs
        :return: padded English and Chinese index tensors
        """
        global device  # device is defined in the __main__ block
        en_index, ch_index = [], []
        en_len, ch_len = [], []
        for en, ch in batch_datas:
            en_index.append(en)
            ch_index.append(ch)
            en_len.append(len(en))
            ch_len.append(len(ch))
        # Pad to the length of the longest sample in the batch
        max_en_len = max(en_len)
        max_ch_len = max(ch_len)
        en_index = [i + [self.en_word_2_index["<PAD>"]] * (max_en_len - len(i)) for i in en_index]
        # Target layout: <BOS> + ch_index + <EOS> + some <PAD>s
        ch_index = [
            [self.ch_word_2_index["<BOS>"]] + i + [self.ch_word_2_index["<EOS>"]] + [self.ch_word_2_index["<PAD>"]] * (
                    max_ch_len - len(i)) for i in ch_index]
        en_index = torch.tensor(en_index, device=device)
        ch_index = torch.tensor(ch_index, device=device)
        return en_index, ch_index

    def __len__(self):
        assert len(self.en_data) == len(self.ch_data)
        return len(self.ch_data)
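# Note: batch_data_process is passed to the DataLoader as collate_fn (see the
# commented-out training code below), so every batch is padded to the length of
# its own longest sample rather than to a global maximum. For example, a batch
# [("Hi.", "嗨。"), ("Wait!", "等等!")] pads the English side to 5 columns and
# the Chinese side to <BOS> + 3 chars + <EOS> = 5 columns.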
class Encoder(nn.Module):
    def __init__(self, encoder_embedding_num, encoder_hidden_num, en_corpus_len):
        super().__init__()
        self.embedding = nn.Embedding(en_corpus_len, encoder_embedding_num)
        self.lstm = nn.LSTM(encoder_embedding_num, encoder_hidden_num, batch_first=True)

    def forward(self, en_index):
        en_embedding = self.embedding(en_index)
        _, encoder_hidden = self.lstm(en_embedding)
        return encoder_hidden  # (h_n, c_n)
class Decoder(nn.Module):
    def __init__(self, decoder_embedding_num, decoder_hidden_num, ch_corpus_len):
        super().__init__()
        self.embedding = nn.Embedding(ch_corpus_len, decoder_embedding_num)
        self.lstm = nn.LSTM(decoder_embedding_num, decoder_hidden_num, batch_first=True)

    def forward(self, decoder_input, hidden):  # hidden: (h_n, c_n)
        embedding = self.embedding(decoder_input)
        decoder_output, decoder_hidden = self.lstm(embedding, hidden)
        return decoder_output, decoder_hidden  # output, (h_n, c_n)
class Seq2Seq(nn.Module):
    def __init__(self, encoder_embedding_num, encoder_hidden_num, en_corpus_len,
                 decoder_embedding_num, decoder_hidden_num, ch_corpus_len):
        super().__init__()
        self.encoder = Encoder(encoder_embedding_num, encoder_hidden_num, en_corpus_len)
        self.decoder = Decoder(decoder_embedding_num, decoder_hidden_num, ch_corpus_len)
        self.classifier = nn.Linear(decoder_hidden_num, ch_corpus_len)
        self.cross_loss = nn.CrossEntropyLoss()  # cross-entropy loss

    def forward(self, en_index, ch_index):
        decoder_input = ch_index[:, :-1]  # drop the last column (the final <EOS> / trailing <PAD>)
        label = ch_index[:, 1:]  # drop the leading <BOS>
        encoder_hidden = self.encoder(en_index)
        decoder_output, _ = self.decoder(decoder_input, encoder_hidden)
        pre = self.classifier(decoder_output)  # classify the top-layer output at every time step
        # pre:   [batch_size, seq_len, ch_corpus_len] -> [batch_size*seq_len, ch_corpus_len]
        # label: [batch_size, seq_len] -> [batch_size*seq_len]
        loss = self.cross_loss(pre.reshape(-1, pre.shape[-1]), label.reshape(-1))
        return loss
def translate(sentence):
    """
    Translate an English sentence into Chinese.
    :param sentence: English sentence (str)
    :return: None (the translation is printed)
    """
    global en_word_2_index, model, device, ch_word_2_index, ch_index_2_word
    # The outer [] adds a batch dimension
    en_index = torch.tensor([[en_word_2_index[i] for i in sentence]], device=device)
    result = []
    # Encode first
    encoder_hidden = model.encoder(en_index)  # encoder_hidden: (h_n, c_n)
    decoder_hidden = encoder_hidden
    # Start predicting from <BOS>
    decoder_input = torch.tensor([[ch_word_2_index["<BOS>"]]], device=device)
    while True:
        # Then decode one step at a time
        decoder_output, decoder_hidden = model.decoder(decoder_input, decoder_hidden)
        # pre: [1, 1, ch_corpus_len]
        pre = model.classifier(decoder_output)
        w_index = int(torch.argmax(pre, dim=-1))  # greedy decoding: take the argmax index
        word = ch_index_2_word[w_index]  # map the index back to a Chinese character
        # Stop once <EOS> is predicted or the output exceeds 50 characters
        if word == "<EOS>" or len(result) > 50:
            break
        result.append(word)
        # Feed the current output back in as the input at the next time step
        decoder_input = torch.tensor([[w_index]], device=device)
    print("Translation: ", "".join(result))
if __name__ == "__main__":
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    with open("datas\\ch.vec", "rb") as f1:
        # ch_word_2_index: individual Chinese characters
        _, ch_word_2_index, ch_index_2_word = pickle.load(f1)
    with open("datas\\en.vec", "rb") as f2:
        # en_word_2_index: upper/lower-case English letters + punctuation + space
        _, en_word_2_index, en_index_2_word = pickle.load(f2)
    ch_corpus_len = len(ch_word_2_index)
    en_corpus_len = len(en_word_2_index)
    # Add the special tokens
    ch_word_2_index.update({"<PAD>": ch_corpus_len,
                            "<BOS>": ch_corpus_len + 1,  # start-of-sequence token
                            "<EOS>": ch_corpus_len + 2})  # end-of-sequence token
    en_word_2_index.update({"<PAD>": en_corpus_len})
    ch_index_2_word += ["<PAD>", "<BOS>", "<EOS>"]
    en_index_2_word += ["<PAD>"]
    ch_corpus_len += 3
    en_corpus_len += 1
    en_datas, ch_datas = get_datas(nums=200)  # use only the first 200 pairs to speed things up
    # Training code (uncomment to train and save seq2seq.pt)
    # encoder_embedding_num = 50
    # encoder_hidden_num = 100
    # decoder_embedding_num = 107
    # decoder_hidden_num = 100
    #
    # batch_size = 2
    # epoch = 40
    # lr = 0.001
    #
    # dataset = MyDataset(en_datas, ch_datas, en_word_2_index, ch_word_2_index)
    # dataloader = DataLoader(dataset, batch_size, shuffle=False, collate_fn=dataset.batch_data_process)
    #
    # model = Seq2Seq(encoder_embedding_num, encoder_hidden_num, en_corpus_len,
    #                 decoder_embedding_num, decoder_hidden_num, ch_corpus_len)
    # model = model.to(device)
    #
    # opt = torch.optim.Adam(model.parameters(), lr=lr)
    #
    # for e in range(epoch):
    #     for en_index, ch_index in dataloader:
    #         loss = model(en_index, ch_index)
    #         loss.backward()
    #         opt.step()
    #         opt.zero_grad()
    #
    #     print(f"loss:{loss:.3f}")
    #
    # torch.save(model, "seq2seq.pt")
    # print('save success')
    # Prediction code
    model = torch.load('seq2seq.pt', map_location=device)
    while True:
        s = input("Enter an English sentence: ")
        translate(s)
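A side note on checkpointing: torch.save(model, ...) pickles the whole module, which ties the checkpoint to the exact class definitions and file layout. A commonly preferred alternative is to save only the state_dict; here is a sketch, assuming the same hyperparameter values as the commented-out training block above:

# Sketch: state_dict-based checkpointing instead of pickling the full module.
torch.save(model.state_dict(), "seq2seq_state.pt")

# Reloading requires rebuilding the model with the same hyperparameters first.
model = Seq2Seq(encoder_embedding_num=50, encoder_hidden_num=100, en_corpus_len=en_corpus_len,
                decoder_embedding_num=107, decoder_hidden_num=100, ch_corpus_len=ch_corpus_len)
model = model.to(device)
model.load_state_dict(torch.load("seq2seq_state.pt", map_location=device))
model.eval()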
Summary
Training a Seq2Seq model boils down to putting a fully connected layer on top of the decoder and treating every time step as a multi-class classification over the target vocabulary. At inference time, the decoder is instead fed one start token first, and each predicted character is then fed back in as the input for the next step.
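To make this training/inference asymmetry concrete, here is a tiny standalone sketch of the input/label shift performed in Seq2Seq.forward, using made-up token ids:

import torch

# Toy padded target row: <BOS> 你 好 。 <EOS>, encoded with toy ids (1=<BOS>, 2=<EOS>)
ch_index = torch.tensor([[1, 35, 48, 60, 2]])

decoder_input = ch_index[:, :-1]  # tensor([[ 1, 35, 48, 60]])  fed into the decoder
label = ch_index[:, 1:]           # tensor([[35, 48, 60,  2]])  what each step must predict

# At step t the decoder sees the gold character t (teacher forcing) and is trained
# to emit character t+1; at inference it sees its own previous output instead.
print(decoder_input)
print(label)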