Implementing seq2seq in PyTorch (2): the Luong attention mechanism

Implementing seq2seq in PyTorch (1)

This post implements Luong attention. The alignment weights over the encoder states are

$$a_t(s) = \mathrm{align}(h_t, \bar{h}_s) = \frac{\exp\big(\mathrm{score}(h_t, \bar{h}_s)\big)}{\sum_{s'} \exp\big(\mathrm{score}(h_t, \bar{h}_{s'})\big)}$$

where $\bar{h}_s$ denotes the encoder output at each hidden state and $h_t$ the decoder output at each hidden state. The weights form a context vector $c_t = \sum_s a_t(s)\,\bar{h}_s$, which is combined with $h_t$ into the attentional state $\tilde{h}_t = \tanh\big(W_c\,[c_t; h_t]\big)$.
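To make the formulas concrete, here is a minimal PyTorch sketch of a Luong attention layer using the "general" score $\mathrm{score}(h_t, \bar{h}_s) = h_t^\top W_a \bar{h}_s$. The class name, tensor shapes, and the choice of score variant are illustrative assumptions, not necessarily identical to the model built later in this series.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LuongAttention(nn.Module):
    # Luong attention with the "general" score (an assumption for this sketch)
    def __init__(self, enc_hidden_size, dec_hidden_size):
        super().__init__()
        self.W_a = nn.Linear(enc_hidden_size, dec_hidden_size, bias=False)                    # W_a in the score
        self.W_c = nn.Linear(enc_hidden_size + dec_hidden_size, dec_hidden_size, bias=False)  # W_c for h~_t

    def forward(self, dec_output, enc_outputs, mask=None):
        # dec_output : (batch, tgt_len, dec_hidden_size)  -- h_t for every decoder step
        # enc_outputs: (batch, src_len, enc_hidden_size)  -- h_bar_s for every encoder step
        # mask       : (batch, 1, src_len), 0 where the source token is padding
        scores = torch.bmm(dec_output, self.W_a(enc_outputs).transpose(1, 2))  # (batch, tgt_len, src_len)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)   # padded positions get ~zero attention
        attn = F.softmax(scores, dim=-1)                   # a_t(s)
        context = torch.bmm(attn, enc_outputs)             # c_t = sum_s a_t(s) * h_bar_s
        output = torch.tanh(self.W_c(torch.cat([context, dec_output], dim=-1)))  # h~_t
        return output, attn

At decoding time dec_output can be a single step (tgt_len = 1) or the whole teacher-forced target sequence; either way the softmax over src_len gives one weight per encoder position.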

import os
import sys
import math
from collections import Counter
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

import nltk
import jieba

1. Load the raw data

def load_data(file_path):
    if 'en' in file_path:
        result = []
        with open(file_path,'r') as f:
            for line in f:
                result.append(['BOS']+nltk.word_tokenize(line.lower())+['EOS'])
    if 'zh' in file_path:
        result = []
        with open(file_path,'r',encoding='utf-8') as f:
            for line in f:
                result.append(['BOS']+[c for c in line]+['EOS'])
    return result
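
load_data dispatches on the file name: English files are tokenized with nltk.word_tokenize, Chinese files are split into single characters, and every sentence is wrapped in BOS/EOS. The function is not actually called below (the raw lines are read directly and tokenized again in section 2.1), but a typical call, using the paths assumed in this post, would be:

en_token_all = load_data('./seq2seq_data/train_en.txt')   # word-level English tokens
ch_token_all = load_data('./seq2seq_data/train_zh.txt')   # character-level Chinese tokens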


ch_text = []
with open('./seq2seq_data/train_zh.txt','r',encoding='utf-8') as f:
    for line in f:
        ch_text.append(line)
        
en_text = []
with open('./seq2seq_data/train_en.txt','r',encoding='utf-8') as f:
    for line in f:
        en_text.append(line)
print(en_text[:5])
print(ch_text[:5])
['A pair of red - crowned cranes have staked out their nesting territory\n', 'A pair of crows had come to nest on our roof as if they had come for Lhamo.\n', "A couple of boys driving around in daddy's car.\n", 'A pair of nines? You pushed in with a pair of nines?\n', 'Fighting two against one is never ideal,\n']
['一对丹顶鹤正监视着它们的筑巢领地\n', '一对乌鸦飞到我们屋顶上的巢里,它们好像专门为拉木而来的。\n', '一对乖乖仔开着老爸的车子。\n', '一对九?一对九你就全下注了?\n', '一对二总不是好事,\n']

2. Data preprocessing

2.1 Tokenize the Chinese and English sentences

# The raw corpus is very large (the Chinese and English files each hold about 10 million sentences), so we only sample 20,000 pairs for training
choose_idx = random.sample(range(10000000),20000)

choose_ch = [ch_text[i].strip('\n').strip() for i in choose_idx]
choose_en = [en_text[i].strip('\n').strip() for i in choose_idx]

ch_token = []
for sentence in choose_ch:
    ch_token.append(['BOS']+[c for c in sentence]+['EOS'])

en_token = []
for sentence in choose_en:
    en_token.append(['BOS']+nltk.word_tokenize(sentence.lower())+['EOS'])    
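
random.sample(range(10000000), 20000) hard-codes the corpus size of 10 million lines; if the files ever hold fewer lines, indexing ch_text/en_text with the sampled ids raises an IndexError. A slightly more defensive variant (same sampling idea, just with the range taken from the data actually loaded) could be:

n_pairs = min(len(ch_text), len(en_text))                        # aligned sentence pairs actually available
choose_idx = random.sample(range(n_pairs), min(20000, n_pairs))  # never ask for more pairs than exist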

2.2 Build the vocabularies

UNK_IDX = 0
PAD_IDX = 1
def build_dict(sentences, max_words=50000):
    word_count = Counter()
    for sentence in sentences:
        for s in sentence:
            word_count[s] += 1
    ls = word_count.most_common(max_words)
    total_words = len(ls) + 2
    word_dict = {w[0]: index + 2 for index, w in enumerate(ls)}
    # w[0] is the word and w[1] its count; indices start at 2 because
    # positions 0 and 1 are reserved for UNK and PAD.
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX
    return word_dict, total_words
    # In the finished dict every key is a token and every value is a unique integer index.


# stoi: maps word -> index
ch_dict,len_ch_dict = build_dict(ch_token)
en_dict,len_en_dict = build_dict(en_token)
# itos: maps index -> word
inv_en_dict = {v: k for k, v in en_dict.items()}
inv_ch_dict = {v: k for k, v in ch_dict.items()}
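
As a quick sanity check of the two mappings: UNK and PAD always sit at indices 0 and 1 (set explicitly in build_dict), every other token gets an index >= 2 ordered by frequency, and inv_en_dict / inv_ch_dict simply invert the word-to-index mapping. 'BOS' is used below only as an example key; any token that occurs in the sampled corpus works the same way.

print(en_dict['UNK'], en_dict['PAD'])     # 0 1
print(inv_en_dict[en_dict['BOS']])        # BOS -- stoi followed by itos gives the token back
print(len_en_dict, len_ch_dict)           # vocabulary sizes: distinct tokens + 2 reserved slots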

2.3 Encode the sentences with the vocabularies

def encode(en_sentences, ch_sentences, en_dict, ch_dict, sort_by_len=True):
    # Encode the sequences:
    # map every token in each sentence to its index in the dict (unknown tokens -> UNK_IDX = 0)
    out_en_sentences = [[en_dict.get(w, 0) for w in sent] for sent in en_sentences]
    out_ch_sentences = [[ch_dict.get(w, 0) for w in sent] for sent in ch_sentences]

    # sort sentences by English length
    def len_argsort(seq):
        # key tells sorted how to rank the indices 0..len(seq)-1:
        # each index x is ordered by the length of its sentence, len(seq[x])
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    # reorder Chinese and English with the same permutation (sorted by English length)
    if sort_by_len:
        sorted_index = len_argsort(out_en_sentences)
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_ch_sentences = [out_ch_sentences[i] for i in sorted_index]
        
    return out_en_sentences, out_ch_sentences
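
To make the sorting step concrete, here is a tiny made-up example of what len_argsort returns (the lists are illustrative, not taken from the corpus): it yields the sentence positions ordered from shortest to longest, and the same permutation is then applied to both languages so the pairs stay aligned.

toy = [[3, 7, 9, 2], [3, 5], [8, 1, 4]]
order = sorted(range(len(toy)), key=lambda x: len(toy[x]))
print(order)                    # [1, 2, 0] -- shortest sentence first
print([toy[i] for i in order])  # [[3, 5], [8, 1, 4], [3, 7, 9, 2]]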

en_encode, ch_encode = encode(en_token, ch_token, en_dict, ch_dict)
k = 805
print(" ".join([inv_ch_dict[i] for i in ch_encode[k]]))
print(" ".join([inv_en_dict[i] for i in en_encode