Handwritten digit recognition with a CNN
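# CNN: one convolution layer
# Layers A, B, C
# A (ordinary network): activation function + matrix multiply-add
# A in a CNN: pool(activation function + matrix convolution + add)
# C: activation function + matrix multiply-add (A -> B) + softmax (matrix multiply-add)
# loss in the earlier example: tf.reduce_mean(tf.square(y-layer2))
# loss here: written out in the code below (cross-entropy)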
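The layer outline above can be checked by tracing the tensor shape after each step. The following is a minimal plain-Python sketch. It assumes a batch of 100 and the layer sizes used in the script below: a 5*5 kernel with 32 filters, 4*4 pooling, 1024 hidden units and 10 classes. `trace_shapes` is a hypothetical helper, not part of TensorFlow or the original code.

```python
import math

# Hypothetical helper (not from the original script): lists the tensor shape
# after each layer, assuming the layer settings used in this post.
def trace_shapes(batch=100, size=28, channels=1, filters=32, pool=4,
                 hidden=1024, classes=10):
    shapes = [("input", (batch, size, size, channels))]
    # 'SAME' convolution with stride 1 keeps height and width unchanged
    shapes.append(("conv+relu", (batch, size, size, filters)))
    # 'SAME' pooling with stride `pool` gives ceil(size / pool) per spatial axis
    pooled = math.ceil(size / pool)
    shapes.append(("max_pool", (batch, pooled, pooled, filters)))
    # flatten each sample into one vector before the fully connected layer
    shapes.append(("flatten", (batch, pooled * pooled * filters)))
    shapes.append(("fc+relu", (batch, hidden)))
    shapes.append(("softmax", (batch, classes)))
    return shapes

for name, shape in trace_shapes():
    print(name, shape)
```

The flatten step gives 7*7*32 = 1568, which is why `w1` below has shape `[7*7*32,1024]`.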
#1 import
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
# 2 load data
mnist = input_data.read_data_sets('MNIST_data',one_hot = True)
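With `one_hot = True`, each label is stored as a 10-element vector instead of a single digit. The digit is the position of the 1. A minimal sketch of that encoding in plain Python (`one_hot` here is a hypothetical helper, not the TensorFlow loader's own code):

```python
def one_hot(label, num_classes=10):
    """Return a list with a 1 at position `label` and 0 everywhere else."""
    vec = [0] * num_classes
    vec[label] = 1
    return vec

print(one_hot(3))  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
```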
# 3 input
imageInput = tf.placeholder(tf.float32,[None,784]) # 28*28
labeInput = tf.placeholder(tf.float32,[None,10]) # one-hot labels for the 10 digits (0-9)
# 4 data reshape
# [None,784] -> M*28*28*1: 2-D -> 4-D; 28*28 is width*height, and there is 1 channel because the images are grayscale
imageInputReshape = tf.reshape(imageInput,[-1,28,28,1])
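`tf.reshape` keeps the element order (row-major). Pixel `i` of the 784-long vector therefore ends up at row `i // 28`, column `i % 28`, and `-1` lets TensorFlow infer the batch dimension M. Below is a plain-Python sketch of the same mapping for one image. `reshape_row_major` is a hypothetical helper used only for illustration.

```python
def reshape_row_major(flat, height=28, width=28):
    """Split a flat list of height*width pixels into rows, in row-major order."""
    assert len(flat) == height * width, "pixel count must match height*width"
    return [flat[r * width:(r + 1) * width] for r in range(height)]

flat = list(range(784))          # stand-in for one flattened MNIST image
img = reshape_row_major(flat)
print(img[1][0])                 # 28: the 29th pixel starts the second row
```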
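# 5 convolution: w0 is the kernel, 5*5 in size, 1 input channel, 32 output channels
w0 = tf.Variable(tf.truncated_normal([5,5,1,32],stddev = 0.1)) # weight tensor (variable)
b0 = tf.Variable(tf.constant(0.1,shape=[32])) # bias (variable); shape=[32] matches the last dimension of the weight tensor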
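# 6 layer1: activation function (relu outputs 0 for negative inputs, which makes the output non-linear) + convolution
# tf.nn.conv2d performs the convolution, tf.nn.relu is the activation function
# input imageInputReshape: M*28*28*1, weights w0: 5,5,1,32; strides is the step size of each move of the kernel
# padding='SAME' pads the border with zeros so the kernel can also be centered on edge pixels and the output stays 28*28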
layer1 = tf.nn.relu(tf.nn.conv2d(imageInputReshape,w0,strides=[1,1,1,1],padding='SAME')+b0)
# output size: M*28*28*32
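What `tf.nn.conv2d` does with stride 1 and `padding='SAME'` can be sketched for a single channel in plain Python. It computes a cross-correlation: the kernel slides over the image without being flipped, and positions outside the image count as 0. This is a minimal sketch assuming an odd kernel size, so the padding is `k // 2` on every side. `conv2d_same` is a hypothetical name, not a TensorFlow function.

```python
def conv2d_same(image, kernel):
    """Single-channel 2-D cross-correlation, stride 1, zero 'SAME' padding.
    Assumes a square kernel with odd size k, so the padding is k // 2 per side
    and the output has the same height and width as the input."""
    h, w = len(image), len(image[0])
    k = len(kernel)
    pad = k // 2
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for di in range(k):
                for dj in range(k):
                    r, c = i + di - pad, j + dj - pad
                    if 0 <= r < h and 0 <= c < w:  # outside the image counts as 0
                        acc += image[r][c] * kernel[di][dj]
            out[i][j] = acc
    return out

img = [[1.0] * 3 for _ in range(3)]
ker = [[1.0] * 3 for _ in range(3)]
print(conv2d_same(img, ker))  # [[4.0, 6.0, 4.0], [6.0, 9.0, 6.0], [4.0, 6.0, 4.0]]
```

The corners see fewer real pixels (4) than the center (9) because of the zero padding. In the network, each of the 32 kernels in `w0` produces one such output map, then `b0` is added and `tf.nn.relu` is applied.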
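# The data is large, so it is downsampled with a pooling layer: max_pool keeps the largest value in each window, which cuts the data size a lot
# layer1 is the data to pool; each spatial dimension is divided by the ksize/stride of 4: M*28*28*32 => M*7*7*32 (28/4 = 7)
layer1_pool = tf.nn.max_pool(layer1,ksize=[1,4,4,1],strides=[1,4,4,1],padding='SAME')
# example: [1 2 3 4] -> [4]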
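The pooling step can be sketched in plain Python for one channel. This sketch assumes the window is no larger than the stride and that height and width are multiples of the stride (true for 28 / 4). Under those assumptions, `'SAME'` padding adds nothing and a 28*28 map becomes 7*7. `max_pool` here is a hypothetical helper, not `tf.nn.max_pool`.

```python
def max_pool(image, k=4, stride=4):
    """Keep the maximum of each k x k window, moving `stride` pixels at a time.
    Assumes k <= stride and height/width divisible by stride (no padding needed)."""
    h, w = len(image), len(image[0])
    assert k <= stride and h % stride == 0 and w % stride == 0
    return [
        [max(image[r][c] for r in range(i, i + k) for c in range(j, j + k))
         for j in range(0, w, stride)]
        for i in range(0, h, stride)
    ]

print(max_pool([[1, 2], [3, 4]], k=2, stride=2))  # [[4]], the [1 2 3 4] -> [4] example
```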
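# 7 output layer
# In the stock-prediction example, the layer2 output was activation function + multiply-add; here a fully connected layer (activation function + multiply-add) is followed by softmax(multiply-add)
# w1 is 2-D: [7*7*32,1024]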
w1 = tf.Variable(tf.truncated_normal([7*7*32,1024],stddev=0.1))
b1 = tf.Variable(tf.constant(0.1,shape=[1024]))
# reshape to 2-D
h_reshape = tf.reshape(layer1_pool,[-1,7*7*32])# M*7*7*32 -> N*(7*7*32)
# dimensions of the two matrices being multiplied: [N,7*7*32] x [7*7*32,1024] = N*1024
h1 = tf.nn.relu(tf.matmul(h_reshape,w1)+b1)
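The fully connected layer is `relu(x @ w + b)`. The inner dimensions must match, which is why the pooled output is first flattened to `7*7*32`. With these sizes, `w1` and `b1` together hold 1568*1024 + 1024 = 1,606,656 parameters. The following is a minimal plain-Python sketch on tiny made-up numbers. `dense_relu` is a hypothetical helper, not a TensorFlow API.

```python
def dense_relu(x, w, b):
    """relu(x @ w + b) on nested lists.
    x: N x D, w: D x H, b: length H -> result: N x H (inner D must match)."""
    d, h = len(w), len(w[0])
    assert all(len(row) == d for row in x), "inner dimensions must match"
    return [
        [max(0.0, sum(row[k] * w[k][j] for k in range(d)) + b[j]) for j in range(h)]
        for row in x
    ]

x = [[1.0, 2.0]]                 # N=1 sample, D=2 features
w = [[1.0, -1.0], [0.5, -2.0]]   # D=2 inputs, H=2 outputs
b = [0.1, 0.1]
print(dense_relu(x, w, b))       # second output is negative before relu, so it becomes 0.0
```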
# 7.1 softmax layer
w2 = tf.Variable(tf.truncated_normal([1024,10],stddev=0.1))
b2 = tf.Variable(tf.constant(0.1,shape=[10]))
# prediction
pred = tf.nn.softmax(tf.matmul(h1,w2)+b2)# h1: N*1024, w2: 1024*10 -> N*10 (output)
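Softmax turns each row of 10 scores into probabilities that are positive and sum to 1. A minimal plain-Python sketch for one row follows. Subtracting the maximum score first is a common trick to avoid overflow in `exp`, and it does not change the result. `softmax` here is a hypothetical helper, not `tf.nn.softmax` itself.

```python
import math

def softmax(logits):
    """Convert a list of scores into probabilities that sum to 1."""
    m = max(logits)                          # shift for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
```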
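# pred is N*10: each row holds the probability of each digit, e.g. row 1 [0.1 0.2 0.4 0.1 0.2 ...], where the first entry (0.1) is the probability of digit 0
# labels are N*10 with exactly one 1 per row, e.g. [0 0 0 0 1 0 0 0 ...]
# compute the error with log(pred) (cross-entropy): log of a probability is <= 0 and drops steeply as the probability approaches 0, so confident wrong predictions are penalized heavily
loss0 = labeInput*tf.log(pred) # multiplying by the one-hot label keeps only the log-probability of the correct digit; training pushes that probability up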
# 7.2 sum over the 10 label values of each sample, negate it (gradient descent minimizes, and log-probabilities are <= 0), then average over the batch
# The original spelled this out as a Python double loop over the 100 (originally 500) samples and the 10 labels:
#   loss1 = loss1 - loss0[m,n]; loss = loss1/100
loss = tf.reduce_mean(-tf.reduce_sum(loss0, axis=1))
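The loss is the average cross-entropy over the batch. The double loop over samples `m` and labels `n` gives the same value as taking `-log(p[true class])` per sample and averaging, because the one-hot label zeroes out every other term. A minimal plain-Python sketch comparing both forms on made-up numbers follows. Both function names are hypothetical.

```python
import math

def cross_entropy_loop(labels, preds):
    """Mirror of the double loop: subtract label * log(pred) for every
    sample m and class n, then divide by the number of samples."""
    total = 0.0
    for m in range(len(labels)):
        for n in range(len(labels[m])):
            total -= labels[m][n] * math.log(preds[m][n])
    return total / len(labels)

def cross_entropy_rowwise(labels, preds):
    """Same value per row: with one-hot labels only -log(p[true class]) survives."""
    per_sample = [-math.log(p[row.index(1)]) for row, p in zip(labels, preds)]
    return sum(per_sample) / len(per_sample)

labels = [[0, 1, 0], [1, 0, 0]]
preds = [[0.2, 0.7, 0.1], [0.5, 0.25, 0.25]]
print(cross_entropy_loop(labels, preds), cross_entropy_rowwise(labels, preds))
```

A probability of exactly 0 makes `math.log` fail here. In TensorFlow, `tf.log(0)` gives `-inf`, and multiplying that by a 0 label gives NaN, so a plain `tf.log(pred)` loss can become NaN if softmax underflows.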
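# 8 train: reduce the error. GradientDescentOptimizer(0.01) runs gradient descent with a learning rate (step size) of 0.01; minimize(loss) updates the variables to make the loss as small as possible
# this completes the network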
train = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
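The learning rate 0.01 scales each update step. It is not a fixed amount taken off the loss. A minimal plain-Python sketch of gradient descent on a toy one-variable loss follows, assuming f(w) = (w - 3)^2. Both the toy loss and the helper names are hypothetical and only illustrate the update rule w <- w - learning_rate * gradient.

```python
def gradient_descent(grad, w, learning_rate=0.01, steps=1):
    """Repeat w <- w - learning_rate * grad(w) for `steps` iterations."""
    for _ in range(steps):
        w = w - learning_rate * grad(w)
    return w

def grad_f(w):
    # gradient of the toy loss f(w) = (w - 3) ** 2, whose minimum is at w = 3
    return 2 * (w - 3)

print(gradient_descent(grad_f, 0.0, steps=1))     # one small step toward 3
print(gradient_descent(grad_f, 0.0, steps=1000))  # close to 3
```

TensorFlow applies the same rule to every variable (`w0`, `b0`, `w1`, ...) using gradients of `loss` computed automatically.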
# 9 run: check how well the network trains
# Build the accuracy ops once, outside the loop, so the graph does not grow on every iteration:
# compare the index of the largest predicted probability with the index of the 1 in the label
acc = tf.equal(tf.argmax(pred,1),tf.argmax(labeInput,1))
# cast the booleans to tf.float32 and take the mean to get the accuracy
acc_float = tf.reduce_mean(tf.cast(acc,tf.float32))
with tf.Session() as sess:
    # initialize all variables
    sess.run(tf.global_variables_initializer())
    for i in range(100):
        # read the next batch of training images
        # original: images,labels = mnist.train.next_batch(500)
        images,labels = mnist.train.next_batch(100)
        # imageInput: image input, labeInput: label input
        sess.run(train,feed_dict={imageInput:images,labeInput:labels})
        # evaluate on the test set: pred gives the probability of each digit 0-9,
        # acc_float is the fraction of test images whose most probable digit matches the label
        acc_result = sess.run(acc_float,feed_dict={imageInput:mnist.test.images,labeInput:mnist.test.labels})
        print(acc_result)
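The accuracy check takes the index of the largest value in each prediction row and in each label row and counts how often they agree. A minimal plain-Python sketch on made-up numbers follows. `argmax` and `accuracy` are hypothetical helpers. This `argmax` returns the first index on ties, whereas TensorFlow does not guarantee which index `tf.argmax` returns on ties.

```python
def argmax(values):
    """Index of the largest value (the first one if several are tied)."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(preds, labels):
    """Fraction of rows where the most probable class equals the labelled class."""
    hits = [argmax(p) == argmax(y) for p, y in zip(preds, labels)]
    return sum(hits) / len(hits)

preds = [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1], [0.2, 0.2, 0.6], [0.3, 0.4, 0.3]]
labels = [[0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0]]
print(accuracy(preds, labels))  # 0.75: the third row predicts class 2 instead of 1
```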
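- The computation takes a while, so be patient. I changed the data size from 500 to 100 (the batch size and the loss divisor), and the run prints 100 results: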
WARNING:tensorflow:From <ipython-input-1-11d6df98fb93>:14: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\Users\16603\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\Users\16603\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data\train-images-idx3-ubyte.gz
WARNING:tensorflow:From C:\Users\16603\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data\train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\16603\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\Users\16603\Anaconda3\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From <ipython-input-1-11d6df98fb93>:79: arg_max (from tensorflow.python.ops.gen_math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.math.argmax` instead
0.1179
0.1131
0.0871
0.1395
0.1559
0.2124
0.244
0.2199
0.2329
0.27
0.3111
0.3332
0.3702
0.3982
0.4292
0.439
0.4431
0.4978
0.5278
0.5499
0.5393
0.5826
0.5889
0.5929
0.5972
0.5604
0.5988
0.6306
0.6259
0.6448
0.654
0.6713
0.6641
0.6992
0.6757
0.6956
0.69
0.7205
0.7126
0.7145
0.7268
0.7315
0.7338
0.7257
0.7387
0.718
0.7409
0.728
0.7368
0.7606
0.775
0.756
0.7638
0.7587
0.7907
0.7855
0.7767
0.7649
0.7867
0.7941
0.7721
0.7926
0.791
0.7699
0.794
0.7891
0.8091
0.7922
0.8074
0.8021
0.7991
0.7926
0.8153
0.8169
0.8152
0.822
0.8068
0.8184
0.8274
0.8272
0.8346
0.8337
0.8195
0.8364
0.8282
0.8363
0.8109
0.8304
0.8445
0.8364
0.8384
0.8407
0.8418
0.8368
0.8462
0.8309
0.8275
0.8468
0.8479
0.8426