深度学习_用LSTM构建单词纠错神器(2)

本文详细介绍了如何通过encoder-decoder模型进行单词纠错,包括数据生成过程(随机修改单词并添加边界标识)、embedding的处理方法,以及使用LSTM进行模型训练。展示了关键步骤和参数设置,以及训练过程中损失函数的变化情况。
摘要由CSDN通过智能技术生成

四、单词纠错神器-训练数据构建

4.1 encoder decoder的输入单词

主要有三个步骤:

  • 将原始单词以随机的方式改为不正确的(input1)
  • 将辅助输入结果增加起始符和末尾符(input2)
  • 将最终的输出增加末尾符(target)
  def generate(self, thresh=0.2):
        """
        训练数据
        1- 生成 input1 input2
        2- embedding
        """
        lines = self._load_data()
        test_size = int(len(lines) * self.test_size)
        input1_list, input2_list = [], []
        input1_max_len, input2_max_len = 0, 0
        for w in lines[:-test_size]:
            if len(w) > 10:
            	# 将原始单词以随机的方式改为不正确的(input1) 
                input1_word = self.gen_gibberish(w, thresh=thresh)
            	# 将辅助输入结果增加起始符和末尾符(input2)
                input2_word = f'\t{w}\n'
                
                input1_list.append(input1_word)
                input2_list.append(input2_word)

                input1_max_len = max(input1_max_len, len(input1_word))
                input2_max_len = max(input2_max_len, len(input2_word))

        # 2- embedding
        # - 将最终的输出增加末尾符(target)
        return self.word_embedding(input1_list, input2_list, input1_max_len, input2_max_len)

4.2 将单词embedding

做成[sample_num, encode_len, char_len]的张量。
【注】:

  • sample_num:样本数
  • encode_len:单词的长度 / 每个单词的字母的位置(取最长)
  • char_len: 字符的个数

例如:cab

对于encode 的input1的单词Embedding之后如下。
encode_len: idx=0, char_len: idx=2 : c
encode_len: idx=1, char_len: idx=0 : a
encode_len: idx=2, char_len: idx=1 : b
在这里插入图片描述

    def word_embedding(self, input1_list, input2_list, input1_max_len, input2_max_len):
        """
        当没有提供embedding的方法的时候,
        采用最简单的字母位置及出现则标记为1, 否则标记为0。 便于后面一个一个字母预测的时候抽取字母
        """
        samples_count = len(input1_list)
        input1_encode_data = np.zeros((samples_count, input1_max_len, len(self.char_set)), dtype='float64')
        input2_decode_data = np.zeros((samples_count, input2_max_len, len(self.char_set)), dtype='float64')
        target_data = np.zeros((samples_count, input2_max_len, len(self.char_set)), dtype='float64')

        # 将矩阵填充上数据 某个字母出现一次则标记增加1
        for num_idx, (inp1_w, inp2_w) in enumerate(zip(input1_list, input2_list)):
            for w_idx, chr_tmp in enumerate(inp1_w):
                input1_encode_data[num_idx, w_idx, self.char2int[chr_tmp]] = 1

            for w_idx, chr_tmp in enumerate(inp2_w):
                input2_decode_data[num_idx, w_idx, self.char2int[chr_tmp]] = 1
                if w_idx > 0: # 预测起始符后的 - 将最终的输出增加末尾符(target)
                    target_data[num_idx, w_idx - 1, self.char2int[chr_tmp]] = 1
        return input1_encode_data, input2_decode_data, target_data

五、模型训练

构建了数据集, 模型后。我们需要训练模型(建议采用GPU [笔者白嫖的Kaggle的GPU])
如果用本地的CPU跑,内存不够的时候, 想简单看下效果,可以将 epochs改成100, batch_size改小一点,即每次进入模型的时候样本数少一点。

g = GeneratData()
input1_encode_data, input2_decode_data, target_data = g.generate()

# lstm的cell个数, 输出的Dense的维度, encode的单词的长度(最后个维度), decode的单词的长度(最后个维度)
m = de_right_word_tf2(256, 39, 39, 39)
m.summary()
his_ = m.fit([tf.constant(input1_encode_data), tf.constant(input2_decode_data)], tf.constant(target_data),
	epochs=500,
    batch_size=256,
    validation_split=0.2
)
"""
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_7 (InputLayer)            [(None, None, 39)]   0
__________________________________________________________________________________________________
input_8 (InputLayer)            [(None, None, 39)]   0
__________________________________________________________________________________________________
lstm_6 (LSTM)                   [(None, 256), (None, 303104      input_7[0][0]
__________________________________________________________________________________________________
lstm_7 (LSTM)                   [(None, None, 256),  303104      input_8[0][0]
                                                                 lstm_6[0][1]
                                                                 lstm_6[0][2]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, None, 39)     10023       lstm_7[0][0]
==================================================================================================
Total params: 616,231
Trainable params: 616,231
Non-trainable params: 0
__________________________________________________________________________________________________
>>>
"""

训练的损失收敛情况:
在这里插入图片描述

"""
91/91 [==============================] - 2s 21ms/step - loss: 0.0921 - val_loss: 0.3849
Epoch 497/500
91/91 [==============================] - 2s 20ms/step - loss: 0.0932 - val_loss: 0.3884
Epoch 498/500
91/91 [==============================] - 2s 20ms/step - loss: 0.0917 - val_loss: 0.3866
Epoch 499/500
91/91 [==============================] - 2s 21ms/step - loss: 0.0911 - val_loss: 0.3869
Epoch 500/500
91/91 [==============================] - 2s 20ms/step - loss: 0.0921 - val_loss: 0.3876
"""
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Scc_hy

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值