PyTorch Learning (7): Seq2Seq and Attention

Video: https://www.bilibili.com/video/BV1vz4y1R7Mm?p=7

Seq2Seq, Attention

First download the dataset (the nmt folder) from
https://github.com/ZeweiChu/PyTorch-Course/tree/master/notebooks

import os
import sys
import math
from collections import Counter
import numpy as np
import random

import torch
import torch.nn as nn

import nltk

print("torch", torch.__version__)
print("nltk", nltk.__version__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch 1.2.0+cu92
nltk 3.5
def load_data(in_file):
    """Read the training data, tokenize it, and return English/Chinese token lists."""
    cn = []
    en = []
    with open(in_file, 'r', encoding='utf8') as f:
        for line in f:
            line = line.strip().split('\t')
            # English: word-level tokens; Chinese: character-level tokens
            en.append(["BOS"] + nltk.word_tokenize(line[0].lower()) + ["EOS"])
            cn.append(["BOS"] + [c for c in line[1]] + ["EOS"])
    return en, cn
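
nltk.word_tokenize depends on the punkt tokenizer models; if they are not installed yet, a one-time download is needed (a setup note, assuming a default NLTK installation):

import nltk
nltk.download("punkt")  # one-time download of the word tokenizer models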

train_file = "nmt/en-cn/train.txt"
dev_file = "nmt/en-cn/dev.txt"
train_en, train_cn = load_data(train_file)
dev_en, dev_cn = load_data(dev_file)
train_en[:3]
[['BOS', 'anyone', 'can', 'do', 'that', '.', 'EOS'],
 ['BOS', 'how', 'about', 'another', 'piece', 'of', 'cake', '?', 'EOS'],
 ['BOS', 'she', 'married', 'him', '.', 'EOS']]
UNK_IDX = 0
PAD_IDX = 1
def build_dict(sentences, max_words=50000):
    """构建词表"""
    word_count = Counter()
    for sentence in sentences:
        for s in sentence:
            word_count[s] += 1
    ls = word_count.most_common(max_words)
    total_words = len(ls) + 2  # +2 for the reserved UNK and PAD slots
    word_dict = {w[0]: index+2 for index, w in enumerate(ls)}
    word_dict["UNK"] = UNK_IDX
    word_dict["PAD"] = PAD_IDX
    return word_dict, total_words

en_dict, en_total_words = build_dict(train_en)
cn_dict, cn_total_words = build_dict(train_cn)
inv_en_dict = {v:k for k,v in en_dict.items()}
inv_cn_dict = {v:k for k,v in cn_dict.items()}
def encode(en_sentences, cn_sentences, en_dict, cn_dict, sort_by_len=True):
    """Encode the sentences as word indices, optionally sorted by English length."""
    out_en_sentences = [[en_dict.get(w, UNK_IDX) for w in sent] for sent in en_sentences]
    out_cn_sentences = [[cn_dict.get(w, UNK_IDX) for w in sent] for sent in cn_sentences]
    
    def len_sort(seq):
        """Return the indices that sort seq by length (the same order is used twice below)."""
        return sorted(range(len(seq)), key=lambda x: len(seq[x]))

    # Sorting by length makes each minibatch more uniform, so less padding is wasted
    if sort_by_len:
        sorted_index = len_sort(out_en_sentences)
        out_en_sentences = [out_en_sentences[i] for i in sorted_index]
        out_cn_sentences = [out_cn_sentences[i] for i in sorted_index]
        
    return out_en_sentences, out_cn_sentences

train_en, train_cn = encode(train_en, train_cn, en_dict, cn_dict)
dev_en, dev_cn = encode(dev_en, dev_cn, en_dict, cn_dict)
k = 10000
print(" ".join([inv_en_dict[i] for i in train_en[k]]))
print(" ".join([inv_cn_dict[i] for i in train_cn[k]]))
BOS for what purpose did he come here ? EOS
BOS 他 来 这 里 的 目 的 是 什 么 ? EOS
def get_minibatches(n, minibatch_size, shuffle):
    """返回[[3,4,5], [0,1,2], [6,7,8]]这种数据"""
    idx_list = np.arange(0, n, minibatch_size)
    if shuffle:
        np.random.shuffle(idx_list)
    minibatches = []
    for idx in idx_list:
        minibatches.append(np.arange(idx, min(idx+minibatch_size,n)))
    return minibatches
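
A quick sanity check of the batching helper (with shuffle=False the batches stay in order):

print(get_minibatches(8, 3, shuffle=False))
# [array([0, 1, 2]), array([3, 4, 5]), array([6, 7])]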

def prepare_data(seqs):
    """Pad variable-length sequences to the same length with 0."""
    lengths = [len(seq) for seq in seqs]
    n_samples = len(seqs)
    max_len = np.max(lengths)
    
    x = np.zeros((n_samples, max_len)).astype(np.int32)
    x_lengths = np.array(lengths).astype(np.int32)
    for idx, seq in enumerate(seqs):
        x[idx, :lengths[idx]] = seq
    return x, x_lengths  # keep x_lengths for building masks later
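
For example, two sequences of lengths 3 and 2 (toy numbers) become a 2x3 array plus a length vector:

x, x_len = prepare_data([[2, 5, 7], [2, 9]])
print(x)      # [[2 5 7]
              #  [2 9 0]]
print(x_len)  # [3 2]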

def gen_examples(en_sentences, cn_sentences, batch_size, shuffle=False):
    minibatches = get_minibatches(len(en_sentences), batch_size, shuffle)
    all_ex = []
    for minibatch in minibatches:
        # pull out the sentences belonging to this minibatch
        mb_en_sentences = [en_sentences[t] for t in minibatch]
        mb_cn_sentences = [cn_sentences[t] for t in minibatch]
        mb_x, mb_x_len = prepare_data(mb_en_sentences)
        mb_y, mb_y_len = prepare_data(mb_cn_sentences)
        all_ex.append((mb_x, mb_x_len, mb_y, mb_y_len))
    return all_ex

batch_size = 64
train_data = gen_examples(train_en, train_cn, batch_size, shuffle=True)
dev_data = gen_examples(dev_en, dev_cn, batch_size, shuffle=False)
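
A quick look at the first batch confirms the layout (the widths depend on the longest sentence in that batch):

mb_x, mb_x_len, mb_y, mb_y_len = train_data[0]
print(mb_x.shape, mb_y.shape)  # (64, max_en_len) and (64, max_cn_len) for this batch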

A simple model without attention

class SimpleEncoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(SimpleEncoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, lengths):
        # sort the sequences in the batch by descending length, as required by pack_padded_sequence
        sorted_len, sorted_idx = lengths.sort(0, descending=True)
        x_sorted = x[sorted_idx]
        embedded = self.dropout(self.embed(x_sorted))
        
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()
        
        return out, hid[[-1]]  # hid[[-1]] keeps the layer dimension: 1, batch_size, hidden_size
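
pack_padded_sequence lets the GRU skip the padded positions entirely; in PyTorch 1.2 it requires the batch to be sorted by descending length, which is why the encoder sorts and then restores the original order (newer versions also accept enforce_sorted=False). A minimal round trip with toy tensors:

emb = torch.randn(2, 3, 4)       # batch of 2, max length 3, embedding size 4
lens = torch.tensor([3, 2])      # true lengths, already descending
packed = nn.utils.rnn.pack_padded_sequence(emb, lens, batch_first=True)
unpacked, out_lens = nn.utils.rnn.pad_packed_sequence(packed, batch_first=True)
print(unpacked.shape, out_lens)  # torch.Size([2, 3, 4]) tensor([3, 2])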
        
class SimpleDecoder(nn.Module):
    def __init__(self, vocab_size, hidden_size, dropout=0.2):
        super(SimpleDecoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.rnn = nn.GRU(2*hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(2*hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, y, y_lengths, hid):
        # sort by target length (same trick as the encoder); the hidden state
        # has to be reordered along with the sequences
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx]
        embedded = self.dropout(self.embed(y_sorted))  # batch_size, y_lengths, hidden_size
        hid = hid[:, sorted_idx]

        # condition every decoder input on the encoder context by concatenation
        embedded = torch.cat([embedded, hid.squeeze(0).unsqueeze(1).expand_as(embedded)], 2)  # batch_size, y_lengths, hidden_size*2
        
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid2 = self.rnn(packed_embedded, hid)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()  # batch_size, y_lengths, hidden_size
        hid = hid[:, original_idx.long()].contiguous()  # restore the encoder context to the original order as well
        hid2 = hid2[:, original_idx.long()].contiguous()  # 1, batch_size, hidden_size; the hidden state has no length dimension
        
        out = torch.cat([out, hid.squeeze(0).unsqueeze(1).expand_as(out)], 2)
        out = self.fc(out)  # batch_size, y_lengths, vocab_size
        out = nn.functional.log_softmax(out, -1)  # softmax over the last dimension (the vocabulary)
        
        return out, hid2
    
class SimpleSeq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(SimpleSeq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid = self.decoder(y, y_lengths, hid)
        return output, None
    
    def translate(self, x, x_lengths, y, max_length=10):
        """Greedy decoding: feed each step's argmax prediction back in as the next input."""
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        for i in range(max_length):
            # keep threading the decoder's hidden state through the loop
            output, hid = self.decoder(y=y, y_lengths=torch.ones(batch_size).long().to(device), hid=hid)
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)
        
        return torch.cat(preds, 1), None
class LanguageModelCriterion(nn.Module):
    """Masked cross-entropy: average negative log-likelihood over the non-padded positions."""
    def __init__(self):
        super(LanguageModelCriterion, self).__init__()
    
    def forward(self, input, target, mask):
        # input: batch_size, seq_len, vocab_size -> (batch_size * seq_len), vocab_size
        input = input.contiguous().view(-1, input.size(2))
        # target, mask: batch_size, seq_len -> (batch_size * seq_len), 1
        target = target.contiguous().view(-1, 1)
        mask = mask.contiguous().view(-1, 1)
        # pick out the log-probability of each gold token, zero out the padding
        output = -input.gather(1, target) * mask
        output = torch.sum(output) / torch.sum(mask)
        
        return output
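
A tiny hand-checkable example of the masked loss (hypothetical numbers): with two positions, one of which is padding, only the real position's log-probability contributes.

crit = LanguageModelCriterion()
logp = torch.log(torch.tensor([[[0.5, 0.5], [0.9, 0.1]]]))  # batch=1, seq_len=2, vocab=2
target = torch.tensor([[0, 0]])
mask = torch.tensor([[1.0, 0.0]])  # the second position is padding
print(crit(logp, target, mask))   # tensor(0.6931) = -log(0.5); the masked position is ignored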
dropout = 0.2
hidden_size = 100
encoder = SimpleEncoder(vocab_size=en_total_words, hidden_size=hidden_size, dropout=dropout)
decoder = SimpleDecoder(vocab_size=cn_total_words, hidden_size=hidden_size, dropout=dropout)
model = SimpleSeq2Seq(encoder, decoder)
model = model.to(device)

loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
def evaluate(model, data):
    model.eval()
    total_num_words = 0
    total_loss = 0
    with torch.no_grad():
        for it, (x, x_len, y, y_len) in enumerate(data):
            x = torch.from_numpy(x).to(device).long()
            x_len = torch.from_numpy(x_len).to(device).long()
            # teacher-forcing shift: the decoder reads y[:, :-1] and must predict y[:, 1:]
            input = torch.from_numpy(y[:, :-1]).to(device).long()
            output = torch.from_numpy(y[:, 1:]).to(device).long()
            y_len = torch.from_numpy(y_len - 1).to(device).long()
            
            pred, attn = model(x, x_len, input, y_len)

            out_mask = torch.arange(y_len.max().item(), device=device)[None, :] < y_len[:, None]
            out_mask = out_mask.float()

            loss = loss_fn(pred, output, out_mask)

            num_words = torch.sum(y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words

    print(f"Evaluation loss {total_loss/total_num_words}")
def train(model, data, num_epochs=30):
    for epoch in range(num_epochs):
        model.train()
        total_num_words = 0
        total_loss = 0
        for it, (x, x_len, y, y_len) in enumerate(data):
            x = torch.from_numpy(x).to(device).long()
            x_len = torch.from_numpy(x_len).to(device).long()
            input = torch.from_numpy(y[:, :-1]).to(device).long()
            output = torch.from_numpy(y[:, 1:]).to(device).long()
            y_len = torch.from_numpy(y_len - 1).to(device).long()
            
            pred, attn = model(x, x_len, input, y_len)
            
            out_mask = torch.arange(y_len.max().item(), device=device)[None, :] < y_len[:, None]
            out_mask = out_mask.float()
            
            loss = loss_fn(pred, output, out_mask)
            
            num_words = torch.sum(y_len).item()
            total_loss += loss.item() * num_words
            total_num_words += num_words
            
            # update the model; clip gradients to keep RNN training stable
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 5.)
            optimizer.step()
            
            if it % 20 == 0:
                print(f"Epoch {epoch} iteration {it} loss {loss.item()}")
            
        print(f"Epoch {epoch} Tranning loss {total_loss/total_num_words}")
        evaluate(model, dev_data)

train(model, train_data, num_epochs=2)
Epoch 0 iteration 0 loss 4.1712327003479
Epoch 0 iteration 20 loss 4.116885185241699
Epoch 0 iteration 40 loss 3.6872975826263428
Epoch 0 iteration 60 loss 3.561073064804077
Epoch 0 iteration 80 loss 3.853266954421997
Epoch 0 iteration 100 loss 3.8749101161956787
Epoch 0 iteration 120 loss 3.787318468093872
Epoch 0 iteration 140 loss 3.8778975009918213
Epoch 0 iteration 160 loss 3.499333143234253
Epoch 0 iteration 180 loss 3.4058444499969482
Epoch 0 iteration 200 loss 3.708536386489868
Epoch 0 iteration 220 loss 3.7492809295654297
Epoch 0 Training loss 4.090564886316073
Evaluation loss 3.973093799748185
Epoch 1 iteration 0 loss 3.813424825668335
Epoch 1 iteration 20 loss 3.7274014949798584
Epoch 1 iteration 40 loss 3.332293748855591
Epoch 1 iteration 60 loss 3.245396375656128
Epoch 1 iteration 80 loss 3.549337863922119
Epoch 1 iteration 100 loss 3.5778181552886963
Epoch 1 iteration 120 loss 3.4655778408050537
Epoch 1 iteration 140 loss 3.6115944385528564
Epoch 1 iteration 160 loss 3.233412027359009
Epoch 1 iteration 180 loss 3.12416672706604
Epoch 1 iteration 200 loss 3.426990032196045
Epoch 1 iteration 220 loss 3.5066943168640137
Epoch 1 Training loss 3.796283439917883
Evaluation loss 3.787926131268771
print(" ".join([inv_en_dict[i] for i in dev_en[0]]))
print("".join([inv_cn_dict[i] for i in dev_cn[0]]))
x = torch.from_numpy(np.array([dev_en[0], dev_en[1]])).long().to(device)
x_len = torch.from_numpy(np.array([len(dev_en[0]), len(dev_en[1])])).long().to(device)
bos = torch.Tensor([[cn_dict["BOS"]], [cn_dict["BOS"]]]).long().to(device)

translation, attn = model.translate(x, x_len, bos)
translation = [[inv_cn_dict[i] for i in sentence] for sentence in translation.data.cpu().numpy()]
print(translation)
BOS look around . EOS
BOS四处看看。EOS
[['她', '的', '人', '都', '喜', '欢', '迎', '是', '我', '們'], ['汤', '姆', '在', '这', '个', '人', '都', '是', '汤', '姆']]

Improved version: Seq2Seq with attention

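The attention used below is Luong's "general" attention. For decoder state h_t and encoder states \bar h_s, the scores, weights, context vector, and attended output are:

score(h_t, \bar h_s) = h_t^T W_a \bar h_s
a_t(s) = softmax_s(score(h_t, \bar h_s))
c_t = \sum_s a_t(s) \bar h_s
\tilde h_t = tanh(W_c [c_t; h_t])

In the code, linear_in plays the role of W_a and linear_out the role of W_c; padded source positions get a score of -1e6, so their weights vanish after the softmax.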

class Attention(nn.Module):
    """
    Luong attention.
    Given the encoder outputs (context vectors) and the current decoder
    hidden states, compute the attended output.
    """
    def __init__(self, encoder_hidden_size, decoder_hidden_size):
        super(Attention, self).__init__()
        self.encoder_hidden_size = encoder_hidden_size
        self.decoder_hidden_size = decoder_hidden_size
        
        self.linear_in = nn.Linear(encoder_hidden_size*2, decoder_hidden_size, bias=False)
        self.linear_out = nn.Linear(encoder_hidden_size*2 + decoder_hidden_size, decoder_hidden_size)
        
    def forward(self, output, context, mask):
        # output: batch_size, output_len, decoder_hidden_size
        # context: batch_size, context_len, encoder_hidden_size*2
        
        batch_size = output.size(0)
        output_len = output.size(1)
        input_len = context.size(1)
        
        # project the context so it can be scored against the decoder states
        # context_in.transpose(1,2): batch_size, decoder_hidden_size, context_len
        # output: batch_size, output_len, decoder_hidden_size
        context_in = self.linear_in(context.view(batch_size * input_len, -1)).view(batch_size, input_len, -1)
        attn = torch.bmm(output, context_in.transpose(1, 2))  # batch matrix-matrix product
        
        # set masked positions to a very negative value so they vanish in the softmax
        # (masked_fill is out of place, so the result must be assigned back)
        attn = attn.masked_fill(mask, -1e6)
        
        # normalize the scores over the source positions
        attn = nn.functional.softmax(attn, dim=2)
        
        # weighted sum of the context vectors, then mix back into the output
        context = torch.bmm(attn, context)
        output = torch.cat((context, output), dim=2)
        
        output = output.view(batch_size * output_len, -1)
        output = torch.tanh(self.linear_out(output))
        output = output.view(batch_size, output_len, -1)
        return output, attn
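
A quick shape check with random tensors (sizes are arbitrary; the all-False mask leaves everything visible):

att = Attention(encoder_hidden_size=100, decoder_hidden_size=100)
dec_states = torch.randn(2, 4, 100)  # batch, output_len, decoder_hidden_size
enc_states = torch.randn(2, 6, 200)  # batch, context_len, 2*encoder_hidden_size
mask = torch.zeros(2, 4, 6, dtype=torch.bool)
out, attn = att(dec_states, enc_states, mask)
print(out.shape, attn.shape)  # torch.Size([2, 4, 100]) torch.Size([2, 4, 6])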
class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, encoder_hidden_size, decoder_hidden_size, dropout=0.2):
        super(Encoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, encoder_hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(encoder_hidden_size*2, decoder_hidden_size)
        
    def forward(self, x, lengths):
        sorted_len, sorted_idx = lengths.sort(0, descending=True)
        x_sorted = x[sorted_idx]
        embedded = self.dropout(self.embed(x_sorted))
        
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, sorted_len.long().cpu().data.numpy(), batch_first=True)
        packed_out, hid = self.rnn(packed_embedded)
        out, _ = nn.utils.rnn.pad_packed_sequence(packed_out, batch_first=True)
        
        _, original_idx = sorted_idx.sort(0, descending=False)
        out = out[original_idx.long()].contiguous()
        hid = hid[:, original_idx.long()].contiguous()
        
        # concatenate the final forward (hid[-2]) and backward (hid[-1]) states,
        # then project down to the decoder's hidden size
        hid = torch.cat([hid[-2], hid[-1]], dim=1)
        hid = torch.tanh(self.fc(hid)).unsqueeze(0)
        
        return out, hid


class Decoder(nn.Module):
    def __init__(self, vocab_size, embed_size, encoder_hidden_size, decoder_hidden_size, dropout=0.2):
        super(Decoder, self).__init__()
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.attention = Attention(encoder_hidden_size, decoder_hidden_size)
        self.rnn = nn.GRU(embed_size, decoder_hidden_size, batch_first=True)
        self.fc = nn.Linear(decoder_hidden_size, vocab_size)
        self.dropout = nn.Dropout(dropout)
        
    def create_mask(self, x_len, y_len):
        # a boolean mask of shape batch_size, max_x_len, max_y_len;
        # True marks pairs where either position is padding
        max_x_len = x_len.max()
        max_y_len = y_len.max()
        x_mask = torch.arange(max_x_len, device=x_len.device)[None, :] < x_len[:, None]
        y_mask = torch.arange(max_y_len, device=x_len.device)[None, :] < y_len[:, None]
        mask = (~(x_mask[:, :, None] * y_mask[:, None, :]))
        return mask
        
    def forward(self, context, context_lengths, y, y_lengths, hid):
        # sort by target length, as in the encoder; the hidden state is reordered too
        sorted_len, sorted_idx = y_lengths.sort(0, descending=True)
        y_sorted = y[sorted_idx]
        y_sorted = self.dropout(self.embed(y_sorted))  # batch_size, y_lengths, embed_size
        hid = hid[:, sorted_idx]

        packed_seq = nn.utils.rnn.pack_padded_sequence(y_sorted, sorted_len.long().cpu().data.numpy(), batch_first=True)
        out, hid = self.rnn(packed_seq, hid)
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(out, batch_first=True)
        
        _, original_idx = sorted_idx.sort(0, descending=False)
        
        output_seq = unpacked[original_idx.long()].contiguous()  # batch_size, y_lengths, hidden_size
        hid = hid[:, original_idx.long()].contiguous()  # 1, batch_size, hidden_size; the hidden state has no length dimension
        
        # attend over the encoder outputs; the mask hides padded positions
        mask = self.create_mask(y_lengths, context_lengths)
        output, attention = self.attention(output_seq, context, mask)

        output = nn.functional.log_softmax(self.fc(output), -1)  # softmax over the last dimension (the vocabulary)
        
        return output, hid, attention
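
With toy lengths (hypothetical sizes), create_mask marks every (target, source) pair that touches padding:

dec = Decoder(vocab_size=10, embed_size=8, encoder_hidden_size=8, decoder_hidden_size=8)
mask = dec.create_mask(torch.tensor([1, 2]), torch.tensor([2, 1]))
print(mask.shape)  # torch.Size([2, 2, 2])
print(mask[0])     # tensor([[False, False],
                   #         [ True,  True]]) -- sample 0 has only 1 valid target step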
class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, x, x_lengths, y, y_lengths):
        encoder_out, hid = self.encoder(x, x_lengths)
        output, hid, attn = self.decoder(
                                   context=encoder_out, 
                                   context_lengths=x_lengths, 
                                   y=y,
                                   y_lengths=y_lengths, 
                                   hid=hid)
        return output, attn

    def translate(self, x, x_lengths, y, max_length=100):
        encoder_out, hid = self.encoder(x, x_lengths)
        preds = []
        batch_size = x.shape[0]
        attns = []
        for i in range(max_length):
            output, hid, attn = self.decoder(context=encoder_out,
                                             context_lengths=x_lengths,
                                             y=y, 
                                             y_lengths=torch.ones(batch_size).long().to(device), hid=hid)
            y = output.max(2)[1].view(batch_size, 1)
            preds.append(y)
            attns.append(attn)
        
        return torch.cat(preds, 1), torch.cat(attns, 1)
"""训练"""

dropout = 0.2
embed_size = hidden_size = 100
encoder = Encoder(vocab_size=en_total_words,
                 embed_size=embed_size,
                 encoder_hidden_size=hidden_size,
                 decoder_hidden_size=hidden_size,
                 dropout=dropout)
decoder = Decoder(vocab_size=cn_total_words,
                 embed_size=embed_size,
                 encoder_hidden_size=hidden_size,
                 decoder_hidden_size=hidden_size,
                 dropout=dropout)
model = Seq2Seq(encoder, decoder)
model = model.to(device)
loss_fn = LanguageModelCriterion().to(device)
optimizer = torch.optim.Adam(model.parameters())
train(model, train_data, num_epochs=3)
Epoch 0 iteration 0 loss 8.088303565979004
Epoch 0 iteration 20 loss 6.1903395652771
...
Epoch 0 iteration 220 loss 4.88083028793335
Epoch 0 Training loss 5.531404307857315
Evaluation loss 5.054636112028756
Epoch 1 iteration 0 loss 5.031433582305908
Epoch 1 iteration 20 loss 5.012241363525391
...
Epoch 1 iteration 220 loss 4.389835357666016
Epoch 1 Training loss 4.855371610606542
Evaluation loss 4.594973376533105
Epoch 2 iteration 0 loss 4.569258213043213
Epoch 2 iteration 20 loss 4.4955363273620605
...
Epoch 2 iteration 220 loss 4.008622646331787
Epoch 2 Training loss 4.423912357264402
Evaluation loss 4.228922175465478
k = 120
print(" ".join([inv_en_dict[i] for i in dev_en[k]]))
print("".join([inv_cn_dict[i] for i in dev_cn[k]]))
x = torch.from_numpy(np.array([dev_en[k]])).long().to(device)
x_len = torch.from_numpy(np.array([len(dev_en[k])])).long().to(device)
bos = torch.Tensor([[cn_dict["BOS"]]]).long().to(device)

translation, attn = model.translate(x, x_len, bos)
translation = [[inv_cn_dict[i] for i in sentence] for sentence in translation.data.cpu().numpy()]
print(translation)
BOS i like your room . EOS
BOS我喜欢你的房间。EOS
[['我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我', '不', '是', '我']]
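
The loop of repeated characters is expected after only 3 epochs of greedy decoding; note also that translate always runs for max_length steps, so in practice the prediction should be cut at the first EOS. A small post-processing sketch (trim_at_eos is a hypothetical helper, not part of the original code):

def trim_at_eos(tokens, eos="EOS"):
    """Keep only the tokens before the first EOS."""
    return tokens[:tokens.index(eos)] if eos in tokens else tokens

print(["".join(trim_at_eos(sent)) for sent in translation])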