keras 的 example 文件 cnn_seq2seq.py 解析

该代码是实现一个翻译功能,好像是英语翻译为法语,嗯,我看不懂法语

首先这个代码有一个bug,本人提交了一个pull request来修复,

https://github.com/keras-team/keras/pull/13863/commits/fd44e03a9d17c05aaecc620f8d88ef0fd385254b

但由于官方长久不维护,所以至今尚未合并,

需要把第68行改为:

input_text, target_text, _ = line.split('\t')

然后根据训练数据,对字母进行编码,其中target_token_index中添加了两个字符,开始符号 '\t' 和结束符合 '\n':

print(input_token_index)
{' ': 0, '!': 1, '$': 2, '%': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '0': 9, '1': 10, '2': 11, '3': 12, '5': 13, '6': 14, '7': 15, '8': 16, '9': 17, ':': 18, '?': 19, 'A': 20, 'B': 21, 'C': 22, 'D': 23, 'E': 24, 'F': 25, 'G': 26, 'H': 27, 'I': 28, 'J': 29, 'K': 30, 'L': 31, 'M': 32, 'N': 33, 'O': 34, 'P': 35, 'Q': 36, 'R': 37, 'S': 38, 'T': 39, 'U': 40, 'V': 41, 'W': 42, 'Y': 43, 'a': 44, 'b': 45, 'c': 46, 'd': 47, 'e': 48, 'f': 49, 'g': 50, 'h': 51, 'i': 52, 'j': 53, 'k': 54, 'l': 55, 'm': 56, 'n': 57, 'o': 58, 'p': 59, 'q': 60, 'r': 61, 's': 62, 't': 63, 'u': 64, 'v': 65, 'w': 66, 'x': 67, 'y': 68, 'z': 69}
print(target_token_index)
{'\t': 0, '\n': 1, ' ': 2, '!': 3, '$': 4, '%': 5, '&': 6, "'": 7, '(': 8, ')': 9, ',': 10, '-': 11, '.': 12, '0': 13, '1': 14, '2': 15, '3': 16, '5': 17, '8': 18, '9': 19, ':': 20, '?': 21, 'A': 22, 'B': 23, 'C': 24, 'D': 25, 'E': 26, 'F': 27, 'G': 28, 'H': 29, 'I': 30, 'J': 31, 'K': 32, 'L': 33, 'M': 34, 'N': 35, 'O': 36, 'P': 37, 'Q': 38, 'R': 39, 'S': 40, 'T': 41, 'U': 42, 'V': 43, 'Y': 44, 'a': 45, 'b': 46, 'c': 47, 'd': 48, 'e': 49, 'f': 50, 'g': 51, 'h': 52, 'i': 53, 'j': 54, 'k': 55, 'l': 56, 'm': 57, 'n': 58, 'o': 59, 'p': 60, 'q': 61, 'r': 62, 's': 63, 't': 64, 'u': 65, 'v': 66, 'x': 67, 'y': 68, 'z': 69, '\xa0': 70, '«': 71, '»': 72, 'À': 73, 'Ç': 74, 'É': 75, 'Ê': 76, 'à': 77, 'â': 78, 'ç': 79, 'è': 80, 'é': 81, 'ê': 82, 'ë': 83, 'î': 84, 'ï': 85, 'ô': 86, 'ù': 87, 'û': 88, 'œ': 89, '\u2009': 90, '’': 91, '\u202f': 92}

对,这个演示示例中不是对word进行编码,而是对字母进行编码,

至于原因,我分析应该是这样的,字母数量比较少,这个索引数也不过只有70个而已,但如果对单词进行编码,那随随便便就上千个,维度超大,后面再运算的时候,需要占用极大的内存和GPU

 

然后对输入输出的句子手动进行one-hot编码:

在预处理中,target_text 的首位补了一个'\t',代表句子开始了,末尾补了一个'\n',代表句子结束了

输入数据的尺寸为:

encoder_input_data.shape (10000, 16, 70)
decoder_input_data.shape (10000, 59, 93)
decoder_target_data.shape (10000, 59, 93)

而这个decoder_input_data 和 decoder_target_data 都是翻译后的句子,只不过 decoder_target_data 比 decoder_input_data 提前一位,decoder_input_data 的第一位是 '\t', 第二位才是真实内容,而 decoder_target_data 的第一位直接就是真实内容了。

为什么会把翻译的结果作为模型的输入?

因为在训练模型时,下一位的输出会依赖上一位的值,而在神经网络最开始的时候,如果预测的第一位错了,在预测第二位的时候,就会有一个错误的输入,我们这时候根据一个错误的输入去优化神经网络是走在了错误的方向,所以我们会辅助提供一个正确的值,这样神经网络才是向正确的方向优化

 

神经网络结构

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_2 (InputLayer)            (None, None, 93)     0
__________________________________________________________________________________________________
input_1 (InputLayer)            (None, None, 70)     0
__________________________________________________________________________________________________
conv1d_4 (Conv1D)               (None, None, 256)    71680       input_2[0][0]
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, None, 256)    54016       input_1[0][0]
__________________________________________________________________________________________________
conv1d_5 (Conv1D)               (None, None, 256)    196864      conv1d_4[0][0]
__________________________________________________________________________________________________
conv1d_2 (Conv1D)               (None, None, 256)    196864      conv1d_1[0][0]
__________________________________________________________________________________________________
conv1d_6 (Conv1D)               (None, None, 256)    196864      conv1d_5[0][0]
__________________________________________________________________________________________________
conv1d_3 (Conv1D)               (None, None, 256)    196864      conv1d_2[0][0]
__________________________________________________________________________________________________
dot_1 (Dot)                     (None, None, None)   0           conv1d_6[0][0]
                                                                 conv1d_3[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None)   0           dot_1[0][0]
__________________________________________________________________________________________________
dot_2 (Dot)                     (None, None, 256)    0           activation_1[0][0]
                                                                 conv1d_3[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, None, 512)    0           dot_2[0][0]
                                                                 conv1d_6[0][0]
__________________________________________________________________________________________________
conv1d_7 (Conv1D)               (None, None, 64)     98368       concatenate_1[0][0]
__________________________________________________________________________________________________
conv1d_8 (Conv1D)               (None, None, 64)     12352       conv1d_7[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, None, 93)     6045        conv1d_8[0][0]
==================================================================================================
Total params: 1,029,917
Trainable params: 1,029,917
Non-trainable params: 0
__________________________________________________________________________________________________

在预测的时候,encoder_input_data 就是输入的句子,decoder_input_data 是一个除第一位设置为开始符号'\t'外,其余位均为0的结构,在预测出第一位 decoder_target_data 后,把预测的字符追加到 decoder_input_data 后面一位,然后通过 for 循环预测下一位,以此类推,直到预期长度

因为预测出的结果为编号,需要反向索引为字符,而在反向索引时如果遇到结束符 '\n',就表示句子结束,得到了完整的预测结果

 

____________________________________________________

代码 lstm_seq2seq.py 的数据预处理和上面一致,就不另外写一篇了,神经网络结构为:

______________________________________________________________________________________________________________
Layer (type)                Output Shape                                     Param #        Connected to      
==============================================================================================================
input_1 (InputLayer)        (None, None, 70)                                 0                        
______________________________________________________________________________________________________________
input_2 (InputLayer)        (None, None, 93)                                 0                        
______________________________________________________________________________________________________________
lstm_1 (LSTM)               [(None, 256), (None, 256), (None, 256)]          334848         input_1[0][0]     
______________________________________________________________________________________________________________
lstm_2 (LSTM)               [(None, None, 256), (None, 256), (None, 256)]    358400         input_2[0][0]     
                                                                                            lstm_1[0][1]      
                                                                                            lstm_1[0][2]      
______________________________________________________________________________________________________________
dense_1 (Dense)             (None, None, 93)                                 23901          lstm_2[0][0]      
==============================================================================================================
Total params: 717,149
Trainable params: 717,149
Non-trainable params: 0
______________________________________________________________________________________________________________

——————————————————————

总目录

keras的example文件解析

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值