【TensorFlow】Bi-LSTM 文本分类

用于训练的计算图

# Training graph for a bidirectional-LSTM text classifier (TensorFlow 1.x graph mode).
# Relies on `input_size` (width of each input vector) and `class_num` (number of
# classes) being defined earlier in the script.
train_graph = tf.Graph()
with train_graph.as_default():
    # Time-major input text: [time_step, batch_size, input_size].
    encoder_inputs = tf.placeholder(shape=[None, None, input_size], dtype=tf.float32, name='encoder_inputs')
    # One integer class id per example: [batch_size].
    # `[None]` pins this to a vector; `(None)` is just `None`, i.e. unknown rank.
    text_label = tf.placeholder(shape=[None], dtype=tf.int32)
    # Unpadded length of each sequence: [batch_size].
    text_length = tf.placeholder(shape=[None], dtype=tf.int32)

    # Forward and backward LSTM cells (256 units each), with dropout on their outputs.
    encoder_fw_cell = tf.contrib.rnn.LSTMCell(256)
    encoder_fw_cell = tf.nn.rnn_cell.DropoutWrapper(encoder_fw_cell, output_keep_prob=0.7)
    encoder_bw_cell = tf.contrib.rnn.LSTMCell(256)
    encoder_bw_cell = tf.nn.rnn_cell.DropoutWrapper(encoder_bw_cell, output_keep_prob=0.7)
    # encoder_outputs is an (fw, bw) pair of per-timestep outputs;
    # encoder_final_state is an (fw, bw) pair of LSTM states.
    encoder_outputs, encoder_final_state = tf.nn.bidirectional_dynamic_rnn(
        encoder_fw_cell, encoder_bw_cell, encoder_inputs,
        sequence_length=text_length,
        dtype=tf.float32, time_major=True,
    )
    # Per-timestep concatenation of both directions along the feature axis.
    # Kept for any downstream code that inspects it; not used by the classifier.
    encoder_outputs = tf.concat(encoder_outputs, 2)

    # Sentence representation. Do NOT read encoder_outputs[-1] here:
    # - with sequence_length set, the forward output past a sequence's end is
    #   zero, so every padded example would contribute a zero forward vector;
    # - the backward output at the last timestep has only read the final token.
    # The final states hold the forward pass's state at each sequence's true
    # last token and the backward pass's state after reading the whole sequence.
    fw_final_state, bw_final_state = encoder_final_state
    sentence_repr = tf.concat([fw_final_state.h, bw_final_state.h], 1)
    # DropoutWrapper(output_keep_prob=...) drops outputs only, not states, so
    # re-apply the same dropout to keep the regularization the
    # output-based version had.
    sentence_repr = tf.nn.dropout(sentence_repr, keep_prob=0.7)

    # Fully connected head.
    # NOTE(review): fc1 has no activation, so fc1 -> logits_ composes to a
    # single linear map; add a nonlinearity (e.g. ReLU) if one was intended.
    fc1 = tf.contrib.layers.linear(sentence_repr, 256)
    logits_ = tf.contrib.layers.linear(fc1, class_num)

    prediction = tf.argmax(logits_, 1)
    # Per-example softmax cross-entropy on integer labels. Equivalent to the
    # one-hot dense form, but out-of-range labels are reported instead of
    # being silently turned into an all-zero target.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=text_label,
        logits=logits_,
    )

    # argmax returns int64, so compare against the labels cast to int64.
    correct_prediction = tf.equal(prediction, tf.cast(text_label, tf.int64))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    loss = tf.reduce_mean(cross_entropy)
    train_op = tf.train.AdamOptimizer().minimize(loss)

    saver = tf.train.Saver()

阅读更多
换一批

没有更多推荐了,返回首页