一、目的
在我们日常开发环境(pycharm / VSCode
)中都自带单词纠错的插件。由于近段时间准备重新梳理下NLP的知识,所以准备从这个单词纠错的插件着手,逐步构建出一个单词纠错神器。
里面主要会涉及输入输出数据集的构建,以及单词纠错网络的搭建。以及输入输出数据集构建的优化,网络框架的优化(会输出tensorflow2.6 以及 pytorch1.8 )。
二、单词纠错神器网络搭建思路
第一直觉是sequence to sequence
的模型, 输入一个拼写错误的单词,输出一个拼写正确的单词。
基于该直觉构建网络图如下:
网络简单搭建方式:
- 需要对输入input_1以及input_2进行编码
- input_2编码的时候需要加入特殊符号表示单词的开始和结束
- 损失函数确定:交叉熵
- 梯度优化两个备选: rmsprop 和 adam
- 网络搭建
三、单词纠错神器网络搭建
3.1 tensorflow 2.0及以上版本
from tensorflow.keras.layers import Dense, LSTM, Input
from tensorflow.keras import Model
def de_right_word_tf2(lstm_units, out_dims, encode_max_len, decode_max_len):
encoder_lstm = LSTM(lstm_units, return_state=True)
# 需要将各个隐层的结果作为下一层的输入时,选择设置 return_sequences=True
decoder_lstm = LSTM(lstm_units, return_state=True, return_sequences=True)
fc = Dense(out_dims, activation='softmax')
input_1 = Input(shape=(None, encode_max_len))
encode_out, encode_h, encode_c = encoder_lstm(input_1)
input_2 = Input(shape=(None, decode_max_len))
decode_out, decode_h, decode_c = decoder_lstm(input_2, initial_state=[encode_h, encode_c])
predict_out = fc(decode_out)
model = Model([input_1, input_2], predict_out)
model.compile(
optimizer='rmsprop',
loss=['categorical_crossentropy']
)
return model