用tensorflow搭建BiLSTM+Attention构成的seq2seq模型

概要

实践BilSTM及Attention在seq2seq模型中的应用。

tensorflow 2.0+,python3.8+

数据集来源:manything.org

模型目标:将英文翻译为德文

模型构建

导入必要库

import numpy as np
import tensorflow as tf
import re
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, Dense, Attention, Embedding, Concatenate, Bidirectional
from tensorflow.keras.callbacks import ModelCheckpoint

数据预处理

数据集下载: manythings.org英译德

//数据集预览
Go.	Geh.	CC-BY 2.0 (France) Attribution: tatoeba.org #2877272 (CM) & #8597805 (Roujin)
Hi.	Hallo!	CC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #380701 (cburgmer)
Hi.	Grüß Gott!	CC-BY 2.0 (France) Attribution: tatoeba.org #538123 (CM) & #659813 (Esperantostern)
Run!	Lauf!	CC-BY 2.0 (France) Attribution: tatoeba.org #906328 (papabear) & #941078 (Fingerhut)
Run.	Lauf!	CC-BY 2.0 (France) Attribution: tatoeba.org #4008918 (JSakuragi) & #941078 (Fingerhut)
Wow!	Potzdonner!	CC-BY 2.0 (France) Attribution: tatoeba.org #52027 (Zifre) & #2122382 (Pfirsichbaeumchen)
Wow!	Donnerwetter!	CC-BY 2.0 (France) Attribution: tatoeba.org #52027 (Zifre) & #2122391 (Pfirsichbaeumchen)
Duck!	Kopf runter!	CC-BY 2.0 (France) Attribution: tatoeba.org #280158 (CM) & #9968521 (wolfgangth)
Fire!	Feuer!	CC-BY 2.0 (France) Attribution: tatoeba.org #1829639 (Spamster) & #1958697 (Tamy)
Help!	Hilfe!	CC-BY 2.0 (France) Attribution: tatoeba.org #435084 (lukaszpp) & #575889 (MUIRIEL)
Help!	Zu Hülf!	CC-BY 2.0 (France) Attribution: tatoeba.org #435084 (lukaszpp) & #2122375 (Pfirsichbaeumchen)
Stay.	Bleib!	CC-BY 2.0 (France) Attribution: tatoeba.org #8907595 (CK) & #5344007 (wochenweise)
Stop!	Stopp!	CC-BY 2.0 (France) Attribution: tatoeba.org #448320 (CM) & #626467 (jakov)
Stop!	Anhalten!	CC-BY 2.0 (France) Attribution: tatoeba.org #448320 (CM) & #7481623 (Yorwba)
Wait!	Warte!	CC-BY 2.0 (France) Attribution: tatoeba.org #1744314 (belgavox) & #2122378 (Pfirsichbaeumchen)
Wait.	Warte.	CC-BY 2.0 (France) Attribution: tatoeba.org #3048304 (camilozeta) & #8597806 (Roujin)
Begin.	Fang an.	CC-BY 2.0 (France) Attribution: tatoeba.org #6102432 (mailohilohi) & #4942826 (Hans_Adler)
Do it.	Mache es!	CC-BY 2.0 (France) Attribution: tatoeba.org #6049386 (mailohilohi) & #899757 (Esperantostern)
def load_data(path):
    with open(path, 'r', encoding='utf-8') as f:
        lines = f.read().split('\n')
    return lines

data_path = './deu.txt' # 数据集位置
lines = load_data(data_path) #按行读入数据

# 数据预处理
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()

for line in lines[:min(10000, len(lines) - 1)]:
    #line = re.sub('[!?.]','',line).lower() # 清除标点,转换大小写
    packs = line.split('\t')
    if len(packs) == 2:continue # 数据集中有一个特例需剔除
    input_text, target_text, _ = packs
    target_text = '\t' + target_text + '\n'
    input_texts.append(input_text) # 读入英文文本
    target_texts.append(target_text) # 读入德文文本

    for char in input_text:
        if char not in input_characters:
            input_characters.add(char) 
    for char in target_text:
        if char not in target_characters:
            target_characters.add(char) # 按字符读入

input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
num_encoder_tokens = len(input_characters) # 统计总字符种类
num_decoder_tokens = len(target_characters)
max_encoder_seq_length = max([len(txt) for txt in input_texts]) 
max_decoder_seq_length = max([len(txt) for txt in target_texts]) # 统计最长单个文本长度

input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)]) # 建立字符字典

encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length, num_encoder_tokens), dtype='float32')
decoder_input_data = np.zeros((len(input_texts), max_decoder_seq_length, num_decoder_tokens), dtype='float32')
decoder_target_data = np.zeros((len(input_texts), max_decoder_seq_length, num_decoder_tokens), dtype='float32')

for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):
    for t, char in enumerate(input_text):
        encoder_input_data[i, t, input_token_index[char]] = 1.0
    for t, char in enumerate(target_text):
        decoder_input_data[i, t, target_token_index[char]] = 1.0
        if t > 0:
            decoder_target_data[i, t - 1, target_token_index[char]] = 1.0 # 建立onehot类型数据,目标文本(德文)往前移动一位

LSTM模型构建

# 构建Seq2Seq模型
latent_dim = 256

# 编码器
encoder_inputs = Input(shape=(None, num_encoder_tokens)) # ((输入个数),词数量,词向量维度)
encoder = LSTM(latent_dim, return_state=True) # (None, LSTM维度)
encoder_outputs, state_h, state_c = encoder(encoder_inputs) # (None,最后h,h,c)
encoder_states = [state_h, state_c]

# 解码器
decoder_inputs = Input(shape=(None, num_decoder_tokens))
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True) # (None,词数量,词向量维度)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)
decoder_dense = Dense(num_decoder_tokens, activation='softmax') #(词向量维度)
decoder_outputs = decoder_dense(decoder_outputs)

# 定义模型
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# 编译模型
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()

LSTM+Attention模型构建

用Attention层连接decoder_outputs和encoder_outputs,在attention中,通常将decoder_outputs参数放在前面,如此attention会更关注decoder_outputs。

以下回答来自ChatGPT

在 Seq2Seq 模型中,注意力机制通常有三个关键的输入:

1. **Query (Q)**: 来自解码器的当前时间步的输出。

2. **Key (K)**: 来自编码器所有时间步的输出。

3. **Value (V)**: 同样来自编码器所有时间步的输出。

这三个输入(Q、K、V)通过一些权重矩阵相乘和其他操作,产生一个加权的编码器输出,这个加权的输出将用于生成解码器的最终输出。

以下是简化的注意力机制的工作原理:

1. **计算注意力分数(Attention Scores)**:
   - 首先,计算 Query 和 Key 之间的相似度(通常使用点积或其他方法),得到注意力分数。

2. **计算注意力权重(Attention Weights)**:
   - 对注意力分数进行缩放,然后通过 Softmax 函数获得归一化的注意力权重。这使得所有权重的总和等于 1。

3. **加权求和(Weighted Sum)**:
   - 将编码器输出(Values)与注意力权重相乘,并将结果加和。这个加权和成为注意力机制的输出。

在 Keras 中,`Attention` 层的实现通常包括这些步骤,并且可以轻松地嵌入到模型中。在使用 `Attention` 层时,你通常会提供解码器的输出(Query)和编码器的输出(Key 和 Value),然后层将处理上述步骤,并返回加权和的结果。

在decoder_outputs_attention = Concatenate(axis=-1)([decoder_outputs, attention])中,为什么要将这两个连接起来

将 `decoder_outputs` 和 `attention` 连接起来的目的是引入注意力机制的信息,以便更好地指导模型生成当前时刻的输出。这种连接的作用是融合解码器的原始输出和注意力机制的权重信息,使得模型在生成预测时能够更灵活地考虑输入序列的不同部分。

具体而言,连接的操作可以有助于以下几个方面:

1. **融合信息**:通过连接,模型可以将解码器的原始输出与计算得到的注意力权重信息结合起来。这有助于模型更好地融合输入序列中不同部分的信息,从而提高生成的输出的准确性。

2. **引入上下文信息**:连接操作使得每个解码器时刻的输出都能够考虑到整个输入序列的上下文信息。这样,模型在生成每个时间步的输出时可以更好地关注输入序列中与当前预测相关的部分。

3. **增加模型表达能力**:通过引入注意力机制,模型的表达能力得到了增强,因为它可以动态地调整对输入序列中不同部分的关注程度。连接操作在这个过程中起到了将这两方面信息融合在一起的作用。

需要注意的是,这里连接是在最后一个轴上进行的,这意味着 `decoder_outputs` 和 `attention` 的最后一个轴的大小必须相同,以便顺利进行连接。这样,连接后的张量将包含解码器的原始输出和注意力机制的信息,可以传递给后续的层(例如 Dense 层)进行最终的输出生成。
# 构建Seq2Seq模型
latent_dim = 256

# 编码器
encoder_inputs = Input(shape=(None, num_encoder_tokens)) # ((输入个数),词数量,词向量维度)
encoder = LSTM(latent_dim, return_state=True) # (None, LSTM维度)
encoder_outputs, state_h, state_c = encoder(encoder_inputs) # (None,最后h,h,c)
encoder_states = [state_h, state_c]

# 解码器
decoder_inputs = Input(shape=(None, num_decoder_tokens))
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True) # (None,词数量,词向量维度)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)

# 注意力层
attention = Attention()([decoder_outputs, encoder_outputs])

# 将注意力层的输出和解码器的输出连接起来
decoder_outputs_attention = Concatenate(axis=-1)([decoder_outputs, attention])

# 使用Dense层进行最终输出
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs_attention)

# 定义模型
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# 编译模型
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()

BiLSTM+Attention模型构建

关于BiLSTM的构建,只能应用于encoder编码器上,因为decoder解码器不能反向。

# 构建Seq2Seq模型
latent_dim = 256

# 编码器
encoder_inputs = Input(shape=(None, num_encoder_tokens)) # ((输入个数),词数量,词向量维度)
encoder = Bidirectional(LSTM(128, return_state=True)) # (None, LSTM维度)
encoder_outputs, forward_h, forward_c, backward_h, backward_c = encoder(encoder_inputs) # (None,最后h,h,c)
state_h = Concatenate()([forward_h, backward_h])
state_c = Concatenate()([forward_c, backward_c])
encoder_states = [state_h, state_c]

# 解码器
decoder_inputs = Input(shape=(None, num_decoder_tokens))
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True) # (None,词数量,词向量维度)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs, initial_state=encoder_states)

# 注意力层
attention = Attention()([decoder_outputs, encoder_outputs])

# 将注意力层的输出和解码器的输出连接起来
decoder_outputs_attention = Concatenate(axis=-1)([decoder_outputs, attention])

# 使用Dense层进行最终输出
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs_attention)

# 定义模型
model = Model([encoder_inputs, decoder_inputs], decoder_outputs)

# 编译模型
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()

模型训练

# 训练模型
epochs = 40
batch_size = 64

# 使用checkpoint保存最佳模型
checkpoint = ModelCheckpoint('s2s_BiLSTM_with_Attetion.h5', save_best_only=True)

history = model.fit(
    [encoder_input_data, decoder_input_data],
    decoder_target_data,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=0.2,
    callbacks=[checkpoint]
)

import matplotlib.pyplot as plt

# 获取训练过程中的损失值和准确率
train_loss = history.history['loss']
train_acc = history.history['accuracy']
val_loss = history.history['val_loss']
val_acc = history.history['val_accuracy']

# 绘制损失值曲线
epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='lower right')
plt.title('Training and Validation Loss')

# 绘制准确率曲线
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='upper right')
plt.title('Training and Validation Accuracy')

plt.show()

数据测试

# 加载最佳模型
#model.load_weights('s2s_LSTM.h5')

# 使用模型进行翻译
encoder_model = Model(encoder_inputs, encoder_states)

decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]

decoder_outputs, state_h, state_c = decoder_lstm(
    decoder_inputs, initial_state=decoder_states_inputs
)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model(
    [decoder_inputs] + decoder_states_inputs, [decoder_outputs] + decoder_states
)

# 将序列转换为文本
def decode_sequence(input_seq):
    # 编码输入序列为状态向量
    states_value = encoder_model.predict(input_seq)

    # 生成目标序列的空序列(长度为1)
    target_seq = np.zeros((1, 1, num_decoder_tokens))
    # 将目标序列的第一个字符设置为开始字符
    target_seq[0, 0, target_token_index['\t']] = 1.0

    # 用于存储生成的目标序列
    decoded_sentence = ''

    stop_condition = False
    while not stop_condition:
        # 通过解码器预测当前字符的概率分布
        output_tokens, h, c = decoder_model.predict([target_seq] + states_value,verbose=0)

        # 采样一个字符
        sampled_token_index = np.argmax(output_tokens[0, -1, :])
        sampled_char = target_characters[sampled_token_index]
        decoded_sentence += sampled_char

        # 如果达到序列的最大长度或者生成了停止字符,停止生成
        if sampled_char == '\n' or len(decoded_sentence) > max_decoder_seq_length:
            stop_condition = True

        # 更新目标序列(用于下一步的解码)
        target_seq = np.zeros((1, 1, num_decoder_tokens))
        target_seq[0, 0, sampled_token_index] = 1.0

        # 更新解码器的状态
        states_value = [h, c]

    return decoded_sentence

#def text_to_sequences():

for i in range(100):
    # 选择一个输入文本进行翻译
    input_text = encoder_input_data[i:i+1]

    # 进行翻译
    decoded_sentence = decode_sequence(input_text)

    # 打印翻译结果
    print('Input sentence:', input_texts[i])
    print('Decoded sentence:', decoded_sentence)

结果展示

来自于LSTM构建的模型
在这里插入图片描述

Input sentence: go.
Decoded sentence: auf.

Input sentence: hi.
Decoded sentence: auf ein.

Input sentence: hi.
Decoded sentence: auf ein.

Input sentence: run!
Decoded sentence: warte.

Input sentence: run.
Decoded sentence: lauf.

Input sentence: wow!
Decoded sentence: werde auf.

Input sentence: wow!
Decoded sentence: werde auf.

Input sentence: duck!
Decoded sentence: auf auf.

Input sentence: fire!
Decoded sentence: auf auf.

Input sentence: help!
Decoded sentence: hall.

Input sentence: help!
Decoded sentence: hall.

Input sentence: stay.
Decoded sentence: auffen.

Input sentence: stop!
Decoded sentence: fange.

Input sentence: stop!
Decoded sentence: fange.

Input sentence: wait!
Decoded sentence: warten.
...

Input sentence: go.
Decoded sentence: auf.

小结

关于模型的构建目前我看到三种方法:
1.重写模型构建类
2.类似这篇博客中的嵌套方法
3.利用Model.add()方法

目前对于前两种方法不是很熟悉。

模型的效果并不是很好,accuracy只有20%不到,大多数测试中翻译也是错误的。同时在多次训练中出现结果末尾相同字符多次重复的问题。估计是数据处理优化没做好,同时优化方法不够。

  • 6
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
下面是一个基于 PyTorch 的 BiLSTM Attention 模型的示例代码: ```python import torch import torch.nn as nn class BiLSTMAttention(nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(BiLSTMAttention, self).__init__() self.hidden_size = hidden_size self.num_classes = num_classes self.lstm = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True) self.fc = nn.Linear(hidden_size*2, num_classes) self.attention_layer = nn.Sequential( nn.Linear(hidden_size*2, hidden_size), nn.Tanh(), nn.Linear(hidden_size, 1) ) def forward(self, x): lstm_out, _ = self.lstm(x) attention_weights = self.attention_layer(lstm_out).squeeze(2) attention_weights = torch.softmax(attention_weights, dim=1) weighted_lstm_out = torch.bmm(lstm_out.permute(0,2,1), attention_weights.unsqueeze(2)).squeeze(2) out = self.fc(weighted_lstm_out) return out ``` 在这个模型中,我们使用 `nn.LSTM` 来实现 BiLSTM,使用 `nn.Linear` 实现全连接层,使用 `nn.Sequential` 实现 Attention 层。 在 `forward` 方法中,我们首先使用 BiLSTM 对输入进行编码,然后通过 Attention 层计算每个时间步的注意力权重,将这些权重加权求和得到加权后的输出向量,最后通过全连接层输出分类结果。 这个模型的输入 `x` 是一个形状为 `(batch_size, seq_len, input_size)` 的张量,其中 `batch_size` 表示批次大小,`seq_len` 表示序列长度,`input_size` 表示输入特征维度。输出是一个形状为 `(batch_size, num_classes)` 的张量,其中 `num_classes` 表示分类类别数。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值