目录
数据集
从 MNIST 官网下载手写数字数据集,然后自行查找代码将其转换为 bmp 格式,并将训练集和测试集中 0-9 的数字图片分别放入以 0-9 命名的十个文件夹中。
代码
定义三层神经网络:输入层 784 个神经元,隐藏层 30 个神经元,输出层 10 个神经元。
import tensorflow as tf
import numpy as np
import os
from PIL import Image
# Load the training data.
# Each digit 0-9 has its own sub-folder of .bmp images; every image is
# flattened to 784 values, binarised, and paired with a one-hot label.
x_data_train = []
y_lable_train = []
for i in range(10):
    # The directory listed here must be the same one the files are opened
    # from below (the listing previously pointed at a misspelled 'mninst').
    for f in os.listdir('mnist/training-set/%s' % i):
        if not f.endswith('.bmp'):
            continue
        # Context manager releases the file handle as soon as the pixels
        # have been copied into the array.
        with Image.open('mnist/training-set/%s/%s' % (i, f)) as im:
            mtr = np.array(im)
        s = mtr.reshape(1, 784)
        # Binarise in one vectorised step: every non-zero pixel becomes 1.
        s[s != 0] = 1
        # One-hot label: position i is 1, all others are 0.
        c = [0] * 10
        c[i] = 1
        x_data_train.append(s[0])
        y_lable_train.append(c)
print('train data load success!')
# Load the test data.
# x_data_test collects the flattened, binarised images in digit order;
# x_data_num[i] records how many test images were found for digit i.
x_data_test = []
x_data_num = []
for i in range(10):
    count = 0
    for f in os.listdir('mnist/testing_set/%s' % i):
        if not f.endswith('.bmp'):
            continue
        with Image.open('mnist/testing_set/%s/%s' % (i, f)) as im:
            # Must be the array that is reshaped below; it used to be stored
            # under a different name, so a stale training image was reused
            # for every test sample.
            mtr = np.array(im)
        s = mtr.reshape(1, 784)
        # Binarise: every non-zero pixel becomes 1.
        s[s != 0] = 1
        x_data_test.append(s[0])
        count += 1
    # One entry per digit, appended after that digit's folder is exhausted.
    x_data_num.append(count)
print("test data load success!")
# Define the neural network
def Layer(input, inSize, outSize, phi=None):
    """Add one fully connected layer to the graph and return its output.

    Args:
        input: 2-D tensor whose second dimension equals ``inSize``.
        inSize: Number of input features (rows of the weight matrix).
        outSize: Number of units in this layer (columns of the weight matrix).
        phi: Optional activation function applied to the pre-activation.
            When ``None`` the layer is linear.

    Returns:
        ``phi(input @ w + b)`` if ``phi`` is given, otherwise ``input @ w + b``.
    """
    # Scope names are kept exactly as they were so that variable names in
    # existing logs stay the same.
    with tf.name_scope('later'):
        with tf.name_scope('weight'):
            w = tf.Variable(tf.random_normal([inSize, outSize]))
        with tf.name_scope('basis'):
            # The shape must be passed as a single list; the second
            # positional argument of tf.zeros is the dtype.
            b = tf.Variable(tf.zeros([1, outSize]) + 0.1)
        with tf.name_scope('z'):
            # Matrix product, not element-wise multiply: [N, inSize] x
            # [inSize, outSize] -> [N, outSize]; b broadcasts over rows.
            z = tf.matmul(input, w) + b
        if phi is None:
            output = z
        else:
            output = phi(z)
    return output
# Convert the Python lists collected by the loaders into arrays that can be
# fed to the placeholders.
x_data_train = np.array(x_data_train)
y_lable_train = np.array(y_lable_train)
x_data_test = np.array(x_data_test)

# Placeholders: xs takes rows of 784 pixel values, ys rows of 10 one-hot labels.
with tf.name_scope('input'):
    xs = tf.placeholder(tf.float32, [None, 784])
    ys = tf.placeholder(tf.float32, [None, 10])

# 784 -> 30 (ReLU) -> 10 (linear). l2 holds raw logits; the softmax is applied
# inside the loss op.
l1 = Layer(xs, 784, 30, phi=tf.nn.relu)
l2 = Layer(l1, 30, 10, phi=None)

with tf.name_scope('loss'):
    # TF 1.x only accepts keyword arguments here; passing them by name also
    # removes any ambiguity about which tensor is the label.
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=ys, logits=l2))
with tf.name_scope('train'):
    train = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# Run 20000 training steps on the full training set, printing the loss
# measured after each step.
train_feed = {xs: x_data_train, ys: y_lable_train}
for step in range(20000):
    sess.run(train, feed_dict=train_feed)
    current_loss = sess.run(loss, feed_dict=train_feed)
    print('step %s loss is %s' % (step, current_loss))
# Write the graph definition to the log directory.
writer = tf.summary.FileWriter("logs_test/", sess.graph)
# Test the trained model: print the network output for every test image.
# threshold must be a real number (a string or NaN is rejected by NumPy);
# infinity disables summarisation so the full array is printed.
np.set_printoptions(threshold=np.inf)
result = sess.run(l2, feed_dict={xs: x_data_test})
print(result)
# Verify accuracy, one digit at a time.
# Test images are stored in digit order, so digit i occupies the slice
# [flag, flag + num) of x_data_test, where num = x_data_num[i].
flag = 0
for i in range(10):
    num = x_data_num[i]
    if num:
        x_test = x_data_test[flag:flag + num, :]
        result = sess.run(l2, feed_dict={xs: x_test})
        # A prediction is correct when the largest output is at index i.
        corr_num = int(np.sum(np.argmax(result, axis=1) == i))
        rate = float(corr_num) / float(num)
    else:
        # No test images for this digit: report zeros instead of dividing
        # by zero.
        corr_num = 0
        rate = 0.0
    print("Number %s has %s, correct number is %s , accuracy is %s"%(i,num,corr_num,rate))
    # Advance past this digit's whole block of images (not just one image),
    # otherwise the next slice would still contain this digit's samples.
    flag += num