pytorch实现seq2seq(一)
本篇实现的是Luong的attention,即:$\mathrm{score}(h_t,\bar{h}_s)=h_t^{\top} W_a \bar{h}_s$
其中 $\bar{h}_s$ 表示encoder每个hidden_state的输出,$h_t$ 表示decoder每个hidden_state的输出。
文章目录
import os
import sys
import math
from collections import Counter
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import nltk
import jieba
1.载入原始数据
def load_data(file_path):
    """Load one side of a parallel corpus and tokenize it line by line.

    English files ('en' in the path) are word-tokenized with nltk after
    lower-casing; Chinese files ('zh' in the path) are split into single
    characters. Every sentence is wrapped with 'BOS'/'EOS' markers.

    Returns a list of token lists; an empty list if the path matches
    neither pattern (previously this raised UnboundLocalError).
    """
    result = []  # defined up front so an unmatched path returns [] instead of raising
    if 'en' in file_path:
        with open(file_path, 'r') as f:
            for line in f:
                result.append(['BOS'] + nltk.word_tokenize(line.lower()) + ['EOS'])
    elif 'zh' in file_path:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                # strip() the line so '\n' never becomes a vocabulary token
                # (matches the .strip('\n').strip() done later in this script)
                result.append(['BOS'] + list(line.strip()) + ['EOS'])
    return result
# Read both raw parallel corpora; each element keeps its trailing newline,
# exactly as the file iterator yields it.
with open('./seq2seq_data/train_zh.txt', 'r', encoding='utf-8') as fin:
    ch_text = list(fin)
with open('./seq2seq_data/train_en.txt', 'r', encoding='utf-8') as fin:
    en_text = list(fin)

# Peek at the first few sentence pairs to sanity-check alignment.
print(en_text[:5])
print(ch_text[:5])
['A pair of red - crowned cranes have staked out their nesting territory\n', 'A pair of crows had come to nest on our roof as if they had come for Lhamo.\n', "A couple of boys driving around in daddy's car.\n", 'A pair of nines? You pushed in with a pair of nines?\n', 'Fighting two against one is never ideal,\n']
['一对丹顶鹤正监视着它们的筑巢领地\n', '一对乌鸦飞到我们屋顶上的巢里,它们好像专门为拉木而来的。\n', '一对乖乖仔开着老爸的车子。\n', '一对九?一对九你就全下注了?\n', '一对二总不是好事,\n']
2.数据预处理
2.1 中英文分词
# The raw corpus is huge (~10M sentence pairs per language), so we randomly
# pick up to 20k aligned pairs for training. Sampling from the actual corpus
# length (instead of a hard-coded 10_000_000) avoids an error when the files
# hold fewer lines.
sample_size = min(20000, len(ch_text))
choose_idx = random.sample(range(len(ch_text)), sample_size)
choose_ch = [ch_text[i].strip('\n').strip() for i in choose_idx]
choose_en = [en_text[i].strip('\n').strip() for i in choose_idx]

# Chinese side: character-level tokens wrapped with BOS/EOS markers.
ch_token = [['BOS'] + list(sentence) + ['EOS'] for sentence in choose_ch]
# English side: lower-cased nltk word tokens wrapped with BOS/EOS markers.
en_token = [['BOS'] + nltk.word_tokenize(sentence.lower()) + ['EOS'] for sentence in choose_en]
2.2 建立词典
# Reserved vocabulary slots: index 0 for unknown tokens, index 1 for padding.
UNK_IDX = 0
PAD_IDX = 1

def build_dict(sentences, max_words=50000):
    """Build a token -> index vocabulary from tokenized sentences.

    The max_words most frequent tokens get indices starting at 2; indices
    0 and 1 are reserved for 'UNK' and 'PAD' (assigned last, so they win
    if those literal strings also occur in the corpus).

    Returns (word_dict, total_words), where total_words is the number of
    kept tokens plus the two reserved entries.
    """
    counts = Counter(tok for sent in sentences for tok in sent)
    most_common = counts.most_common(max_words)
    word_dict = {word: rank + 2 for rank, (word, _freq) in enumerate(most_common)}
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX
    return word_dict, len(most_common) + 2
# Forward vocabularies: token -> index (every index unique), i.e. "stoi".
ch_dict, len_ch_dict = build_dict(ch_token)
en_dict, len_en_dict = build_dict(en_token)

# Inverse vocabularies: index -> token, i.e. "itos".
def _invert(vocab):
    # Each index is unique, so inverting the mapping loses nothing.
    return dict((idx, tok) for tok, idx in vocab.items())

inv_en_dict = _invert(en_dict)
inv_ch_dict = _invert(ch_dict)
2.3 使用词典来对原始句子进行编码
def encode(en_sentences, ch_sentences, en_dict, cn_dict, sort_by_len=True):
    """Encode tokenized sentence pairs as lists of vocabulary indices.

    Each token is looked up in its dictionary; unknown tokens fall back
    to index 0 (UNK). When sort_by_len is True, both sides are reordered
    together by English sentence length (shortest first), which keeps
    per-batch padding small while preserving pair alignment.

    Returns (encoded_en, encoded_ch).
    """
    out_en_sentences = [[en_dict.get(w, 0) for w in sent] for sent in en_sentences]
    # BUG FIX: look Chinese tokens up in the cn_dict PARAMETER — the original
    # used the module-level global ch_dict, silently ignoring the argument.
    out_ch_sentences = [[cn_dict.get(w, 0) for w in sent] for sent in ch_sentences]

    def len_argsort(seq):
        # Indices 0..len(seq)-1, ordered by the length of the sentence each points to.
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    if sort_by_len:
        # Apply the same permutation to both sides so pairs stay aligned.
        sorted_index = len_argsort(out_en_sentences)
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_ch_sentences = [out_ch_sentences[i] for i in sorted_index]
    return out_en_sentences, out_ch_sentences
# Encode the full token lists into index sequences (sorted by English length).
en_encode, ch_encode = encode(en_token, ch_token, en_dict, ch_dict)
# Spot-check one arbitrary sentence pair by decoding it back to tokens.
k = 805
print(" ".join([inv_ch_dict[i] for i in ch_encode[k]]))
print(" ".join([inv_en_dict[i] for i in en_encode