2021SC@SDUSC
文章目录
本周在上一周对Encoder-Decoder模型的数据集和Encoder的构建部分学习的基础上进一步对 Decoder的构建部分展开了学习,整理笔记和代码实操记录如下。
一、Decoder模型构建
在Decoder端,主要完成:
- 对target数据进行处理
- 构造Decoder
1.target数据处理
target数据有两个作用:
- 在训练过程中,需要将target序列作为输入传给Decoder端RNN的每个阶段,而不是使用前一阶段预测输出,以此来提升模型的准确性。
- 需要用target数据来计算模型的loss。
首先需要对target端的数据进行预处理。将target中的序列作为输入给Decoder端的RNN时,序列中的最后一个字母(或单词)实际上是无用的。
如图,右边的Decoder端,可以看到target序列是[< go >, W, X, Y, Z, < eos >],其中< go >,W,X,Y,Z是每个时间序列上输入给RNN的内容。在此处,< eos >并没有作为输入传递给RNN,因此需要将target中的最后一个字符去掉,同时还需要在前面添加< go >标识,告诉模型这代表一个句子的开始。
使用**tf.strided_slice()**来进行这一步处理。
def process_decoder_input(data, vocab_to_int, batch_size):
'''
补充<GO>并移除最后一个字符
'''
ending = tf.strided_slice(data,[0, 0],[batch_size, -1], [1, 1])
decoder_input = tf.concat([tf.fill([batch_size, 1], vocab_to_int['<GO>']), ending], 1)
return decoder_input
其中tf.fill(dims, v