tensorflow学习6:使用tensorflow mnist分类搭建卷积神经网络CNN

本次博客主要是是搭建自己的卷积神经网络,采用的数据集为常用的mnist手写数字数据集。

第一步准备数据

我们导入mnist数据集,因为在tf 中已经准备好了这个数据集,我们只需要在import的时候插入一下就可以。如果在在对应的文件夹下有这个数据集就不会下载,如果没有那么就会默认下载这个文件夹。

其中batch size 为每一次送入网络样本的个数。

lr 为训练的步长。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt

batch_size=50   
lr=0.001
mnist=input_data.read_data_sets('./mnist',one_hot=True)
test_x=mnist.test.image[:2000]
test_y=mnist.test.image[:2000]

tf_x=tf.placeholder(tf.float32,[None,28*28])/255.
image=tf.reshape(tf_x,[-1,28,28,1])
tf_y=tf.placeholder(tf.int32,[None,10])

第二步搭建我们的CNN网络

conv2d第一个是输入数据,filter是输出的维度,kernel是滑动窗口的大小,stride为每次移动的步长,padding是再图片的外面是否进行扩展。

pooling层控制着图片的下采样。

flat层的目的是把数据展成一行 这个样就可以送入后面的全连接层

#搭建CNN网络
conv1 = tf.layers.conv2d(inputs=image,filter=16,kernel_size=5,stride=1,padding='same'
                        activation=tf.nn.relu)  #(28,28,16)
pooling1=tf.layers.max_poolong2d(conv1,pool_size=2,strides=2) #(14,14,16)
conv2=tf.layers.conv2d(pooling1,filter=32,5,1,'same',tf.nn.relu) #(14,14,32)
pooling2=tf.layers.max_poolong2d(conv2,2,2) #(7,7,32)
flat=tf.reshape(pooling2,[-1,7*7*32]) # 1568列
output=tf.layers.dense(flat,10)  #输出十个类别

#loss
loss = tf.losses.softmax_cross_entropy(onehot_labels=tf_y, logits=output)           # compute cost
train_op = tf.train.AdamOptimizer(lr).minimize(loss)

#计算准确率
accuracy = tf.metrics.accuracy(          # return (acc, update_op), and create 2 local variables
    labels=tf.argmax(tf_y, axis=1), predictions=tf.argmax(output, axis=1),)[1]

第三步把定义好的运算在session中运行。

sess = tf.Session()
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) # the local var is for accuracy_op
sess.run(init_op)     # initialize var in graph
for step in range(600):
    b_x, b_y = mnist.train.next_batch(BATCH_SIZE)
    _, loss_ = sess.run([train_op, loss], {tf_x: b_x, tf_y: b_y})
    if step % 50 == 0:
        accuracy_, flat_representation = sess.run([accuracy, flat], {tf_x: test_x, tf_y: test_y})
        print('Step:', step, '| train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_)

     

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值