卷积神经网络分类mnist手写体数字

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import  input_data
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
import matplotlib.pyplot as plt

class Net:
    def __init__(self):
        self.x = tf.placeholder(tf.float32,[None,28,28,1])
        self.y = tf.placeholder(tf.float32,[None,10])
        self.conv1_w = tf.Variable(tf.random_normal([3,3,1,16],dtype=tf.float32,stddev=0.1))
        self.conv1_b = tf.Variable(tf.zeros([16]))
        self.conv2_w = tf.Variable(tf.random_normal([3,3,16,32],dtype=tf.float32,stddev=0.1))
        self.conv2_b = tf.Variable(tf.zeros([32]))
        self.w1 = tf.Variable(tf.random_normal([7*7*32,128],stddev=0.1))
        self.b1 = tf.Variable(tf.zeros([128]))
        self.w2 = tf.Variable(tf.random_normal([128,10],stddev=0.1))
        self.b2 = tf.Variable(tf.zeros([10]))
    def forward(self):
        self.conv1 = tf.nn.relu(tf.nn.conv2d(self.x,self.conv1_w,strides=[1,1,1,1],padding='SAME')+self.conv1_b)
        self.pool1 = tf.nn.max_pool(self.conv1,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
        self.conv2 = tf.nn.relu(tf.nn.conv2d(self.pool1,self.conv2_w,strides=[1,1,1,1],padding='SAME')+self.conv2_b)
        self.pool2 = tf.nn.max_pool(self.conv2,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME')
        self.flat = tf.reshape(self.pool2,[-1,7*7*32])
        self.y1 = tf.nn.relu(tf.matmul(self.flat,self.w1)+self.b1)
        self.y2 = tf.nn.softmax(tf.matmul(self.y1,self.w2)+self.b2)
    def backward(self):
        self.loss = tf.reduce_mean((self.y2-self.y)**2)
        self.opt = tf.train.AdamOptimizer().minimize(self.loss)
        self.prediction_corect = tf.equal(tf.argmax(self.y2,1),tf.argmax(self.y,1))#比较预测值和真实值是否相等
        self.rst = tf.cast(self.prediction_corect,'float')#将布尔值转化为float类型
        self.accuracy = tf.reduce_mean(self.rst)#求出平均值表示精度(百分数)

if __name__ == '__main__':
    net = Net()
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        a = []
        b = []
        c = []
        for i in range(1000):
            a.append(i)
            x,y = mnist.train.next_batch(100)
            x = x.reshape([100,28,28,1])
            loss,acc,_ = sess.run([net.loss,net.accuracy,net.opt],feed_dict={net.x:x,net.y:y})
            b.append(acc)
            c.append(loss)
            if i%10 == 0:
                plt.subplot(1,2,1)#生成1行两列的子图显示在第一个子图
                plt.plot(a,b)
                plt.title('accuracy rate')
                plt.subplot(1,2,2)#生成1行两列的子图显示在第二个子图
                plt.plot(a,c)
                plt.title('loss')
                plt.pause(0.0001)
            print(loss,acc)

 

利用tensorflow实现的卷积神经网络来进行MNIST手写数字图像的分类。 #导入numpy模块 import numpy as np #导入tensorflow模块,程序使用tensorflow来实现卷积神经网络 import tensorflow as tf #下载mnist数据集,并从mnist_data目录中读取数据 from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('mnist_data',one_hot=True) #(1)这里的“mnist_data” 是和当前文件相同目录下的一个文件夹。自己先手工建立这个文件夹,然后从https://yann.lecun.com/exdb/mnist/ 下载所需的4个文件(即该网址中第三段“Four files are available on this site:”后面的四个文件),并放到目录MNIST_data下即可。 #(2)MNIST数据集是手写数字字符的数据集。每个样本都是一张28*28像素的灰度手写数字图片。 #(3)one_hot表示独热编码,其值被设为true。在分类问题的数据集标注时,如何不采用独热编码的方式, 类别通常就是一个符号而已,比如说是9。但如果采用独热编码的方式,则每个类表示为一个列表list,共计有10个数值,但只有一个为1,其余均为0。例如,“9”的独热编码可以为[00000 00001]. #定义输入数据x和输出y的形状。函数tf.placeholder的目的是定义输入,可以理解为采用占位符进行占位。 #None这个位置的参数在这里被用于表示样本的个数,而由于样本个数此时具体是多少还无法确定,所以这设为None。而每个输入样本的特征数目是确定的,即为28*28。 input_x = tf.placeholder(tf.float32,[None,28*28])/255 #因为每个像素的取值范围是 0~255 output_y = tf.placeholder(tf.int32,[None,10]) #10表示10个类别 #输入层的输入数据input_x被reshape成四维数据,其中第一维的数据代表了图片数量 input_x_images = tf.reshape(input_x,[-1,28,28,1]) test_x = mnist.test.images[:3000] #读取测试集图片的特征,读取3000个图片 test_y = mnist.test.labels[:3000] #读取测试集图片的标签。就是这3000个图片所对应的标签
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值