# MNIST手写数字识别 (MNIST handwritten digit recognition)
## TensorFlow基本概念 (basic TensorFlow concepts)
# coding=utf-8
import tensorflow as tf
import input_data
# Load MNIST from the local 'data/' directory with one-hot encoded labels.
mnist = input_data.read_data_sets('data/', one_hot=True)
train_img = mnist.train.images
train_label = mnist.train.labels
# The held-out evaluation split is mnist.test. Reading mnist.train here (as
# before) made the "test" arrays a silent copy of the training data.
test_img = mnist.test.images
test_label = mnist.test.labels
print("mnist loaded")

# NOTE: to persist the trained model, create the Saver only after the model's
# variables have been defined, and call save() from the session that trained
# them (a Saver built here would have no variables to save), e.g.:
#     saver = tf.train.Saver()
#     save_path = saver.save(sess, "/home/aitab/PycharmProjects/mnist2")
#     print("model saved in file:", save_path)

# Sanity-check the loaded array shapes and show the first training label.
print(train_img.shape)
print(train_label.shape)
print(test_img.shape)
print(test_label.shape)
print(train_label[0])
# Graph inputs: x holds [batch, 784] float features and y holds [batch, 10]
# targets (one-hot, given one_hot=True at load time).
x = tf.placeholder("float", [None, 784])
y = tf.placeholder("float", [None, 10])
# Softmax-regression parameters, zero-initialised.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
# model: unnormalised class scores (logits) and their softmax probabilities
logits = tf.matmul(x, W) + b
actv = tf.nn.softmax(logits)
# loss function: mean cross-entropy over the batch. Computing it from the
# logits with the fused op is numerically stable; the hand-written
# -sum(y * log(softmax)) turns into NaN as soon as a probability underflows
# to exactly 0.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
# optimizer: plain gradient descent on the loss
learning_rate = 0.01
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# prediction: does the highest-probability class match the labelled class?
# (tf.argmax replaces the deprecated tf.arg_max, which TF2 removed)
pred = tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))
# accuracy: fraction of correct predictions in the fed batch
accr = tf.reduce_mean(tf.cast(pred, "float"))
# initializer for every variable defined above
init = tf.global_variables_initializer()
train_epochs = 50
batch_size = 100
step = 5  # report metrics every `step` epochs (and after the last one)

# The context manager closes the session and releases its resources when
# training ends; the original session was never closed.
with tf.Session() as sess:
    sess.run(init)
    batch_num = int(mnist.train.num_examples / batch_size)
    for epoch in range(train_epochs):
        avg_cost = 0.
        for _ in range(batch_num):
            # feed one mini-batch per step
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            feeds = {x: batch_xs, y: batch_ys}
            # A single run applies the update and fetches the batch loss,
            # rather than a second forward pass just to read the cost.
            _, batch_cost = sess.run([optm, cost], feed_dict=feeds)
            avg_cost += batch_cost / batch_num
        # display
        if epoch % step == 0 or epoch == train_epochs - 1:
            # Measure train accuracy on the whole training split, not just
            # the last 100-sample batch, so it is comparable to test_acc.
            feeds_train = {x: mnist.train.images, y: mnist.train.labels}
            feeds_test = {x: mnist.test.images, y: mnist.test.labels}
            train_acc = sess.run(accr, feed_dict=feeds_train)
            test_acc = sess.run(accr, feed_dict=feeds_test)
            # Five conversions need five arguments; the old call passed four
            # (and never printed avg_cost or train_acc).
            print("Epoch: %03d/%03d cost:%.9f train_acc:%.3f test_acc:%.3f"
                  % (epoch, train_epochs, avg_cost, train_acc, test_acc))
print("DONE")
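# OUT:
# The log below appears to come from an earlier version of the print
# statement, "Epoch: #03d/%03d cost:%.9f ..." % (epoch, train_epochs,
# test_acc, test_acc). As a result, "#03d" is literal text, the number after
# "/" is the epoch, "cost" shows train_epochs (50) instead of the loss, and
# train_acc merely repeats test_acc.
# Epoch: #03d/000 cost:50.000000000 train_acc:0.852 test_acc:0.852
# Epoch: #03d/005 cost:50.000000000 train_acc:0.895 test_acc:0.895
# Epoch: #03d/010 cost:50.000000000 train_acc:0.905 test_acc:0.905
# Epoch: #03d/015 cost:50.000000000 train_acc:0.909 test_acc:0.909
# Epoch: #03d/020 cost:50.000000000 train_acc:0.913 test_acc:0.913
# Epoch: #03d/025 cost:50.000000000 train_acc:0.914 test_acc:0.914
# Epoch: #03d/030 cost:50.000000000 train_acc:0.915 test_acc:0.915
# Epoch: #03d/035 cost:50.000000000 train_acc:0.917 test_acc:0.917
# Epoch: #03d/040 cost:50.000000000 train_acc:0.918 test_acc:0.918
# Epoch: #03d/045 cost:50.000000000 train_acc:0.918 test_acc:0.918
# DONE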