# Sequence classifier graph: reshape each example into a 16-step sequence of
# 16-dim vectors, run it through one (dropout-wrapped) LSTM layer, and
# classify from the LSTM output at the last time step.
#
# Names this code expects to be defined earlier in the file (not visible here):
#   tf, np     -- TensorFlow 1.x (tf.contrib is required) and NumPy
#   x, y       -- input tensor and label tensor; y is presumably one-hot,
#                 since it is argmax-ed below (TODO confirm)
#   lstmUnits  -- LSTM hidden size (an earlier version hard-coded 64)
#   keep_prob  -- keep probability passed to the LSTM output dropout
#   n_output   -- number of output classes
# NOTE(review): -1 lets the batch size be inferred; presumably each example
# in x carries 16 * 16 = 256 features.
data = tf.reshape(x, [-1, 16, 16])  # [batch_size, step=16, input_dim=16]

lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits)
lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=keep_prob)
# value.shape: [batch_size, step, lstmUnits]
value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32)

# Output-layer parameters.
# NOTE(review): the init is scaled by 1/sqrt(n_output); fan-in scaling
# (1/sqrt(lstmUnits)) is the more common choice. Left unchanged so that
# training behaviour is preserved.
weight = tf.Variable(tf.truncated_normal([lstmUnits, n_output]) / np.sqrt(n_output))
bias = tf.Variable(tf.constant(0.1, shape=[n_output]))

# Make time the leading axis: value.shape becomes [step, batch_size, lstmUnits].
value = tf.transpose(value, [1, 0, 2])
# Output at the final time step, shape [batch_size, lstmUnits]. Indexing with
# -1 selects the same slice as gathering index step - 1. Unlike
# int(value.get_shape()[0]), it does not require a statically known step count.
last = value[-1]

# relu6(z) / 6 clips the class scores into [0, 1].
# NOTE(review): saturated scores can tie (several 1s, or all 0s); argmax then
# returns the lowest tied index, which can bias the accuracy metric.
pred = tf.nn.relu6(tf.matmul(last, weight) + bias) / 6
y_label = tf.argmax(pred, axis=1)
# Reuse y_label instead of building a second, identical argmax op.
corr = tf.equal(y_label, tf.argmax(y, 1))
accu = tf.reduce_mean(tf.cast(corr, tf.float32))
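# tf.contrib.rnn.BasicLSTMCell()中间程序解释 (explanation of the intermediate steps of the program above)
# 最新推荐文章于 2020-03-19 12:05:48 发布 (latest recommended article published 2020-03-19 12:05:48)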