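1. Introduction
Query correction is an important module in search: users often enter erroneous queries, whether or not they realize it. In a Chinese-language setting the errors take many forms: input-method suggestions that substitute a different homophone and produce a mismatched collocation; inaccurate pronunciation that leads to mistyped pinyin; visually similar characters; and miswritten characters so common that they have become near-conventional. No mature set of rules covers them all.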
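2. Approach
An end-to-end correction model based on seq2seq:
Input: the original query, which is embedded and then fed into a bidirectional encoder layer.
Output: the encoder representations pass through an attention operation and are then fed into the decoder layer, which finally decodes the corrected query.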
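The text does not say which attention scoring function the model uses. As a minimal sketch, assuming dot-product scoring and plain Python lists in place of tensors, one decoder step of the attention operation could look like the following. The names `softmax` and `attend` and the toy vectors are illustrative only and do not come from the model code.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(decoder_state, encoder_outputs):
    """One attention step (assumed dot-product scoring).

    decoder_state:   list of floats, the current decoder hidden state
    encoder_outputs: list of vectors, one per input position
    Returns (weights, context), where context is the weighted sum
    of the encoder outputs.
    """
    # Score every encoder position against the decoder state
    scores = [sum(d * h for d, h in zip(decoder_state, enc))
              for enc in encoder_outputs]
    # Normalize the scores into attention weights that sum to 1
    weights = softmax(scores)
    # Context vector: weighted sum of the encoder outputs, per dimension
    dim = len(decoder_state)
    context = [sum(w * enc[i] for w, enc in zip(weights, encoder_outputs))
               for i in range(dim)]
    return weights, context


# Toy example: three input positions, hidden size 2
encoder_outputs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
decoder_state = [1.0, 0.0]
weights, context = attend(decoder_state, encoder_outputs)
```

In the real model the context vector would be combined with the decoder input at each step; here positions 1 and 3 receive equal weight because they score the same against the decoder state.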
Example of the encoding/decoding process:
3. Source code
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)


class seq2seq(object):
    def __init__(self, FLAGS):
        # Hyperparameters read from the FLAGS config object
        self.vocab_size = FLAGS.vocab_size
        self.embedding_dim = FLAGS.embedding_dim
        self.rnn_hidden_size = FLAGS.rnn_hidden_size
        self.lr = FLAGS.lr
        self.use_dropoutwrapper = FLAGS.use_dropoutwrapper
        self.beam_search_width = FLAGS.beam_search_width
        # Build the graph step by step
        self._init_placeholder()
        self._init_embeddings()
        self._init_encoder()
        self._init_decoder()
        self._init_optimizer()

    def _init_placeholder(self):
        # Token-id batches, shape (batch_size, max_time)
        self.seq_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='seq_inputs')
        self.seq_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='seq_targets')
        self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
        self.batch_size = tf.shape(self.seq_inputs)[0]
        # Non-zero ids map to 1, so each row sum is the true sequence length
        # (this treats id 0 as padding)
        used = tf.sign(tf.abs(self.seq_inputs))
        length = tf.reduce_sum(used, axis=1)
        self.seq_inputs_length