利用Luong注意力机制进行seq2seq的机器翻译

本文介绍了如何运用Luong的注意力机制来实现seq2seq的机器翻译。首先,通过线性变换得到context_in,然后计算注意力权重并进行mask填充,接着将注意力权重与context_in相乘得到context。最后,将context与decoder的输出拼接并经过线性变换得到最终的输出。此外,文章还提到了其他三种注意力实现方式,并简单概述了无注意力的baseline实现。
摘要由CSDN通过智能技术生成

注意力的实现方式有很多种

我们这篇文章使用的Luong的注意力机制

 计算注意力 
        context>>>> [batch_size,seq,2*dechidien_size]
        context_in: torch.Size([64, 8, 100])  [batch_size,seq,deco_size]
        第一步 output: torch.Size([64, 11, 100]) [batch_size,seq_len,hidden_size]
                    
                    context_in: torch.Size([64, 100, 8])  [batch_size,deco_size,seq_len]  context线性变换而来
                  attn=  context_in  bmm output>>[batch_size,out_len,x_len]  batch_size, output_len, context_len}
                    masked_fill
        第二步 context=attn  bmm context_in>>[batch_size,out_len,x_len]  batch_size, output_len, context_len}
                    bmm
                    context: {batch_size, context_len, 2*enc_hidden_size]
                    context=[batch_size,output_len,2*enc_hidden_size]
        第三步  context在和output decoeder的输出做个拼接 
                    context+output=[batch_size,seq_len,2*hiddien_size+hidden_size]
                    经历一个线性变化
                    output=[batch_size,seq,hidden_size]

后面我们会介绍其他的三种注意力的实现方式

数据处理部分 

import os
import sys
import math
from collections import Counter
import numpy as np
import random

import torch
import torch.nn  as nn
import torch.nn.functional as F

import nltk
import jieba
#nltk.download('punkt')

"""
读入中英文数据

英文我们使用nltk的word tokenizer来分词,并且使用小写字母
中文我们直接使用单个汉字作为基本单元
中文用jieba分词 
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_data(in_file):
    cn=[]
    en=[]
    num_examples=0
    with open(in_file,'r')  as f:
        for line  in f:
            line=line.strip().split('\t')
            
            en.append(["BOS"] + nltk.word_tokenize(line[0].lower()) + ["EOS"])
            cn.append(["BOS"]+list(jieba.cut(line[1],cut_all=False))+["EOS"])
            
    return en,cn
    
train_file='/root/torch/nmt/en-cn/train.txt'
dev_file='/root/torch/nmt/en-cn/dev.txt'
train_en,train_cn=load_data(train_file)
#print(train_en)
#print(train_cn)
dev_en,dev_cn=load_data(dev_file)
"""
[['BOS', 'anyone', 'can', 'do', 'that', '.', 'EOS'], ['BOS', 'how', 'about', 'another', 'piece', 'of', 'cake', '?', 'EOS'], 
['BOS', 'she', 'married', 'him', '.', 'EOS'], ['BOS', 'i', 'do', "n't", 'like', 'learning', 'irregular', 'verbs', '.', 'EOS'], 
['BOS', 'it', "'s", 'a', 'whole', 'new', 'ball', 'game', 'for', 'me', '.', 'EOS'], ['BOS', 'he', "'s", 'sleeping', 'like', 'a', 'baby', '.', 'EOS'], 
['BOS', 'he', 'can', 'play', 'both', 'tennis', 'and', 'baseball', '.', 'EOS'], ['BOS', 'we', 'should', 'cancel', 'the', 'hike', '.', 'EOS'], 
['BOS', 'he', 'is', 'good', 'at', 'dealing', 'with', 'children', '.', 'EOS'], ['BOS', 'she', 'will', 'do', 'her', 'best', 'to', 'be', 'here', 'on', 'time', '.', 'EOS'], 
['BOS', 'why', 'are', 'you', 'so', 'good', 'at', 'cooking', '?', 'EOS'], ['BOS', 'he', 'has', 'recovered', 'from', 'his', 'bad', 'cold', '.', 'EOS'], 
['BOS', 'it', "'s", 'a', 'dead', 'end', '.', 'EOS'], ['BOS', 'i', 'rejected', 'the', 'offer', '.', 'EOS'], ['BOS', 'he', 'often', 'quotes', 'milton', '.', 'EOS'], 
['BOS', 'mommy', ',', 'may', 'i', 'go', 'swimming', '?', 'EOS'], ['BOS', 'miyazaki', 'is', 'not', 'what', 'it', 'used', 'to', 'be', '.', 'EOS'], ..........

[['BOS', '任何人', '都', '可以', '做到', '。', 'EOS'], ['BOS', '要', '不要', '再來', '一塊', '蛋糕', '?', 'EOS'], ['BOS', '她', '嫁给', '了', '他', '。', 'EOS'], 
['BOS', '我', '不', '喜欢', '学习', '不规则', '动词', '。', 'EOS'], ['BOS', '這對', '我', '來', '說', '是', '個', '全新', '的', '球類', '遊戲', '。', 'EOS'], 
['BOS', '他', '正', '睡着', ',', '像', '个', '婴儿', '一样', '。', 'EOS'], ['BOS', '他', '既会', '打网球', ',', '又', '会', '打', '棒球', '。', 'EOS'],
['BOS', '我們', '應該', '取消', '這次', '遠足', '。', 'EOS'], ['BOS', '他', '擅長', '應付', '小孩子', '。', 'EOS'],
['BOS', '她', '会', '尽量', '按时', '赶来', '的', '。', 'EOS'],.........................................................................................................................

"""

"""
构造单词表
word2count
word2index
index2word
一个句子一个句子统计



"""
UNK_IDX=0
PAD_IDX=1
def build_dict(sentences,max_words=50000):
    word_count=Counter()
    for sentence in sentences:
        for s in sentence:
            word_count[s]+=1
    ls=word_count.most_common(max_words)
    total_words=len(ls)+2
    word_dict={w[0]:index+2 for index,w in enumerate(ls)}
    
    word_dict["UNK"]=UNK_IDX
    word_dict["PAD"]=PAD_IDX
    
    return word_dict,total_words

    
"""
词表构建


"""
en_dic,en_total_words=build_dict(train_en)
cn_dic,cn_total_words=build_dict(train_cn)

"""
构建index  to  word
"""
inv_en_dict={v:k for  k,v in en_dic.items()}
inv_cn_dict={v:k for k,v in cn_dic.items()}

"""
把单词全部变成数字
"""
def encode(en_sentences,cn_sentences,en_dic,cn_dic,sort_by_len=True):
    """
    encode the sequences
    """
    length=len(en_sentences)
    out_en_sentences=[[en_dic.get(w,0) for  w in sent ] for sent in en_sentences]
    out_cn_sentences=[[cn_dic.get(w,0) for w in sent ]  for sent in cn_sentences]
    
    #将英文句子按照长度进行排序  先返回整个列表的索引  然后在将索引按照句子长度进行排序
    #x 就是前面列表的元素
    def len_argsort(seq):
        return sorted(range(len(seq)) , key=lambda x :len(seq[x]))
    
    #吧中文和英文按照同样的顺序排序 
    if sort_by_len:
        sorted_index=len_argsort(out_en_sentences)
        out_cn_sentences=[out_cn_sentences[i] for i in sorted_index]
        out_en_sentences=[out_en_sentences[i] for i in sorted_index]
        
    return out_en_sentences,out_cn_sentences

train_en,train_cn=encode(train_en,train_cn,en_dic,cn_dic)
dev_en,dev_cn=encode(dev_en,dev_cn,en_dic,cn_dic)

"""
把全部句子转为batch
"""

def get_minibatches(n,
  • 3
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值