# -*- coding: utf-8 -*-
"""
Created on Wed Mar 21 10:33:32 2018
@author: kxq
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# ---- Data ----
# MNIST read from the 'MNIST_data' directory, labels one-hot encoded.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# ---- Training hyperparameters ----
lr = 0.001               # Adam learning rate
training_iters = 100000  # stop after this many training examples have been fed
batch_size = 128         # examples per optimisation step
display_step = 10        # logging interval, in steps

# ---- Model dimensions ----
# Each 28x28 image is treated as a sequence: one image row per time step.
n_inputs = 28            # features per time step (pixels in one row)
n_step = 28              # time steps per sequence (rows per image)
n_hidden_units = 128     # LSTM hidden size
n_classes = 10           # digit classes 0-9
# Start from an empty default graph. When the script is re-run in the same
# interpreter session (e.g. Spyder/IPython), the previous run's variables are
# still in the default graph and building the LSTM again fails with
# "Variable rnn/basic_lstm_cell/kernel already exists".
tf.reset_default_graph()

# Graph inputs: x is a batch of images as sequences of rows, y the matching
# one-hot labels. The batch dimension is left open (None).
x = tf.placeholder(tf.float32, [None, n_step, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])

# Trainable parameters:
#   'in'  projects each input row (n_inputs) to the LSTM width (n_hidden_units);
#   'out' maps the final hidden state to class logits (n_classes).
weight = {
    'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
    'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes])),
}
bias = {
    'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),
    'out': tf.Variable(tf.constant(0.1, shape=[n_classes, ])),
}
def RNN(X, weight, bias):
    """Run an LSTM over the rows of each image and return class logits.

    Args:
        X: float32 tensor reshapeable to (batch, n_step, n_inputs). The batch
            size may vary at run time (the ``x`` placeholder uses ``None``).
        weight: dict with 'in' of shape (n_inputs, n_hidden_units) and 'out'
            of shape (n_hidden_units, n_classes).
        bias: dict with 'in' of shape (n_hidden_units,) and 'out' of shape
            (n_classes,).

    Returns:
        Unscaled logits of shape (batch, n_classes).
    """
    # Input projection. Flatten (batch, n_step, n_inputs) to
    # (batch * n_step, n_inputs) so a single matmul projects every row ...
    X = tf.reshape(X, [-1, n_inputs])
    X_in = tf.matmul(X, weight['in']) + bias['in']
    # ... then restore the time axis: (batch, n_step, n_hidden_units).
    X_in = tf.reshape(X_in, [-1, n_step, n_hidden_units])

    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(
        n_hidden_units, forget_bias=1.0, state_is_tuple=True)
    # Passing dtype (instead of a zero_state built from the global batch_size)
    # lets dynamic_rnn size the zero initial state from the runtime batch.
    # The graph then works for any batch size, not only exactly batch_size.
    _, states = tf.nn.dynamic_rnn(
        lstm_cell, X_in, dtype=tf.float32, time_major=False)

    # With state_is_tuple=True, states is LSTMStateTuple(c, h). h is the
    # hidden output at the last time step (same values as outputs[:, -1, :]).
    return tf.matmul(states.h, weight['out']) + bias['out']
# ---- Loss, optimiser and evaluation ops ----
prediction = RNN(x, weight, bias)  # logits, shape (batch, n_classes)

# Batch-mean softmax cross-entropy, minimised with Adam.
per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
    labels=y, logits=prediction)
cost = tf.reduce_mean(per_example_loss)
optimizer = tf.train.AdamOptimizer(lr)
train_op = optimizer.minimize(cost)

# Accuracy = fraction of examples whose highest logit matches the label.
correct_pred = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# ---- Training ----
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = 0
    # Each step consumes one batch; stop once training_iters examples are fed.
    while step * batch_size < training_iters:
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Reshape each flat 28*28 image into (rows, cols) so every image row
        # becomes one time step of the sequence.
        batch_xs = batch_xs.reshape([batch_size, n_step, n_inputs])
        sess.run(train_op, feed_dict={x: batch_xs, y: batch_ys})
        # Log every display_step steps (previously a hard-coded 20 that
        # ignored the configured display_step).
        if step % display_step == 0:
            # Accuracy on the current *training* batch, not held-out data.
            print(sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys}))
        step += 1
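# 重复运行过程中,会出现报错,
# (Note: when the script is re-run in the same Python session, it fails with:)
#   Variable rnn/basic_lstm_cell/kernel already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:
# 重新reset命令窗口,清除内存就行
# (Fix: reset/restart the console to clear the old graph from memory. Calling
#  tf.reset_default_graph() before building the graph achieves the same in code.)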