4. Word-Correction Tool: Building the Training Data
4.1 Input words for the encoder and decoder
There are three main steps:
- Randomly corrupt the original word (input1)
- Wrap the auxiliary input with a start token and an end token (input2)
- Append an end token to the final output (target)
def generate(self, thresh=0.2):
    """
    Build the training data:
    1- generate input1 / input2
    2- embed them
    """
    lines = self._load_data()
    test_size = int(len(lines) * self.test_size)
    input1_list, input2_list = [], []
    input1_max_len, input2_max_len = 0, 0
    for w in lines[:-test_size]:
        if len(w) > 10:  # only keep words longer than 10 characters
            # Randomly corrupt the original word (input1)
            input1_word = self.gen_gibberish(w, thresh=thresh)
            # Wrap the word with the start token (\t) and end token (\n) (input2)
            input2_word = f'\t{w}\n'
            input1_list.append(input1_word)
            input2_list.append(input2_word)
            input1_max_len = max(input1_max_len, len(input1_word))
            input2_max_len = max(input2_max_len, len(input2_word))
    # 2- embedding
    # - the target (input2 shifted left by one step, ending with the end token) is built inside word_embedding
    return self.word_embedding(input1_list, input2_list, input1_max_len, input2_max_len)
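The corruption helper gen_gibberish is not shown in this section. Below is a minimal sketch of what such a method might look like, assuming it perturbs each position with probability thresh by deleting, swapping, or replacing a letter; the exact edit operations are an assumption, not the author's code.

def gen_gibberish(self, word, thresh=0.2):
    """Hypothetical sketch: randomly corrupt a word.

    With probability `thresh` per character, apply one of three edits:
    delete it, swap it with the next character, or replace it with a
    random lowercase letter. The real implementation may differ.
    """
    import random
    import string
    chars = list(word)
    i = 0
    while i < len(chars):
        if random.random() < thresh:
            op = random.choice(('delete', 'swap', 'replace'))
            if op == 'delete' and len(chars) > 1:
                chars.pop(i)
                continue  # do not advance; a new character now sits at position i
            elif op == 'swap' and i + 1 < len(chars):
                chars[i], chars[i + 1] = chars[i + 1], chars[i]
            else:
                chars[i] = random.choice(string.ascii_lowercase)
        i += 1
    return ''.join(chars)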
4.2 Embedding the words
Each word is turned into a tensor of shape [sample_num, encode_len, char_len].
Notes:
- sample_num: number of samples
- encode_len: word length, i.e. the position of each letter within the word (padded to the longest word)
- char_len: size of the character set
For example, the word "cab":
For the encoder input input1, the embedded word looks like this:
encode_len: idx=0, char_len: idx=2 : c
encode_len: idx=1, char_len: idx=0 : a
encode_len: idx=2, char_len: idx=1 : b
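A minimal standalone check of this one-hot layout, assuming char2int maps 'a' to 0, 'b' to 1, 'c' to 2, as the example above implies:

import numpy as np

char2int = {'a': 0, 'b': 1, 'c': 2}  # assumed mapping for this tiny demo
word = 'cab'

one_hot = np.zeros((len(word), len(char2int)))
for pos, ch in enumerate(word):
    one_hot[pos, char2int[ch]] = 1

print(one_hot)
# [[0. 0. 1.]   <- position 0 is 'c' (index 2)
#  [1. 0. 0.]   <- position 1 is 'a' (index 0)
#  [0. 1. 0.]]  <- position 2 is 'b' (index 1)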
def word_embedding(self, input1_list, input2_list, input1_max_len, input2_max_len):
    """
    When no embedding method is supplied, use the simplest scheme:
    mark a letter's position with 1 where it occurs and 0 elsewhere.
    This makes it easy to extract letters one at a time during prediction.
    """
    samples_count = len(input1_list)
    input1_encode_data = np.zeros((samples_count, input1_max_len, len(self.char_set)), dtype='float64')
    input2_decode_data = np.zeros((samples_count, input2_max_len, len(self.char_set)), dtype='float64')
    target_data = np.zeros((samples_count, input2_max_len, len(self.char_set)), dtype='float64')
    # Fill the tensors: set the entry to 1 wherever a letter occurs
    for num_idx, (inp1_w, inp2_w) in enumerate(zip(input1_list, input2_list)):
        for w_idx, chr_tmp in enumerate(inp1_w):
            input1_encode_data[num_idx, w_idx, self.char2int[chr_tmp]] = 1
        for w_idx, chr_tmp in enumerate(inp2_w):
            input2_decode_data[num_idx, w_idx, self.char2int[chr_tmp]] = 1
            if w_idx > 0:  # everything after the start token, shifted left by one: the final output with an end token (target)
                target_data[num_idx, w_idx - 1, self.char2int[chr_tmp]] = 1
    return input1_encode_data, input2_decode_data, target_data
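word_embedding relies on self.char_set and self.char2int, which are built elsewhere in the class. A plausible sketch of their construction follows; the '\t' and '\n' tokens follow from input2 above, but the rest of the inventory is an assumption (the model summary later shows a 39-character set, so the real inventory contains more symbols than this sketch).

import string

# Hypothetical construction of the character inventory used above.
# '\t' is the start token and '\n' the end token, matching f'\t{w}\n'.
char_set = sorted(set(string.ascii_lowercase) | {'\t', '\n'})
char2int = {ch: idx for idx, ch in enumerate(char_set)}
int2char = {idx: ch for ch, idx in char2int.items()}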
5. Model Training
With the dataset and the model built, we can train the model (a GPU is recommended; the author used a free Kaggle GPU).
If you run on a local CPU with limited memory and just want a quick look at the results, change epochs to 100 and use a smaller batch_size, i.e. feed fewer samples into the model per step.
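A quick smoke-test configuration along those lines might look like the following; the batch size of 64 is an illustrative choice, and g, m, and the data tensors are set up exactly as in the full run below. Note that Keras also accepts numpy arrays directly, so tf.constant is optional here.

# Quick CPU smoke test: fewer epochs, smaller batches (numbers are illustrative)
his_ = m.fit([input1_encode_data, input2_decode_data], target_data,
             epochs=100,
             batch_size=64,
             validation_split=0.2)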
import tensorflow as tf

g = GeneratData()
input1_encode_data, input2_decode_data, target_data = g.generate()
# number of LSTM cells; output Dense dimension; char-set size of the
# encoder input (last dim); char-set size of the decoder input (last dim)
m = de_right_word_tf2(256, 39, 39, 39)
m.summary()
his_ = m.fit([tf.constant(input1_encode_data), tf.constant(input2_decode_data)], tf.constant(target_data),
             epochs=500,
             batch_size=256,
             validation_split=0.2
             )
"""
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_7 (InputLayer) [(None, None, 39)] 0
__________________________________________________________________________________________________
input_8 (InputLayer) [(None, None, 39)] 0
__________________________________________________________________________________________________
lstm_6 (LSTM) [(None, 256), (None, 303104 input_7[0][0]
__________________________________________________________________________________________________
lstm_7 (LSTM) [(None, None, 256), 303104 input_8[0][0]
lstm_6[0][1]
lstm_6[0][2]
__________________________________________________________________________________________________
dense_3 (Dense) (None, None, 39) 10023 lstm_7[0][0]
==================================================================================================
Total params: 616,231
Trainable params: 616,231
Non-trainable params: 0
__________________________________________________________________________________________________
>>>
"""
Loss convergence during training:
"""
91/91 [==============================] - 2s 21ms/step - loss: 0.0921 - val_loss: 0.3849
Epoch 497/500
91/91 [==============================] - 2s 20ms/step - loss: 0.0932 - val_loss: 0.3884
Epoch 498/500
91/91 [==============================] - 2s 20ms/step - loss: 0.0917 - val_loss: 0.3866
Epoch 499/500
91/91 [==============================] - 2s 21ms/step - loss: 0.0911 - val_loss: 0.3869
Epoch 500/500
91/91 [==============================] - 2s 20ms/step - loss: 0.0921 - val_loss: 0.3876
"""