基于tensorflow使用RNN识别手写数字,注释比较详细。
另外我使用同样的代码训练多位验证码,每次预测的结果都是同样一组数字,无法成功训练出可以使用的模型,有了解相关内容的同学可以交流一下。
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib import rnn
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '3'
# 定义网络
def rnn_network(X, W, b, nsteps, diminput, dimhidden):
X1 = tf.transpose(X, [1, 0, 2])
X2 = tf.reshape(X1, [-1, diminput])
H_1 = tf.matmul(X2, W["h1"]) + b["b1"]
H_1 = tf.split(H_1, nsteps, 0)
# 设置神经单元
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden, forget_bias=1.0)
LSTM_O, LSTM_S = rnn.static_rnn(lstm_cell, H_1, dtype=tf.float32)
output = tf.matmu