Tensorflow + Mnist (两层CNN,两层全连接)

前面几天中断了好几天,装了个linux,搭建了一下深度学习环境。入坑tensorflow,算是目前相当方便的一个平台了。环境的搭建我有单独写了个博客。
我搭建的环境
ubuntu16.04LTS + python3.6+tensorflow1.2
我的硬件环境:
i7-4720HQ @2.60ghz*8 + 950m
直接上手tensorflow的入门教程,mnist手写字符的识别,tensorflow的官方文档写了一个手写字符识别的入门CNN网络,但是没有画出网络结构,相信对于初学者还是有点难以理解的。
我这里画了一个草图
这里写图片描述
总的来说就是两层卷积(第一层包括一个卷积(32个5×5的kernel)+一个池化,第二层包括一个卷积(64个5×5的kernel)+一个池化)两层全连接,前面3层的激活函数都是采用了relu:max(0,x),最后一层用softmax输出10类目标

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Jun 25 11:59:53 2017

@author: matthew
"""

import input_data
import tensorflow as tf
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
sess = tf.InteractiveSession()

# build softmax
x = tf.placeholder("float",shape = [None,784])
y_ = tf.placeholder("float",shape = [None,10])

#initialize weights and bias
def weight_variable(shape):
    initial = tf.truncated_normal(shape,stddev = 0.1)
    return tf.Variable(initial)
def bias_variable(shape):
    initial = tf.constant(0.1,shape = shape)
    return tf.Variable(initial)

#conv and pooling
def conv2d(x,w):
    return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding = 'SAME')
def max_pool_2x2(x):
    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding = 'SAME')

#first layer,a conv layer and a pooling layer
#conv layer:[5,5,1,32](the size of kernel)
w_conv1 = weight_variable([5,5,1,32])
b_conv1 = bias_variable([32])

x_image = tf.reshape(x, [-1,28,28,1])#-1 denotes orignal size
h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

#second layer,a conv layer and a pooling layer
#kernal size[5,5,32,64]
w_conv2 = weight_variable([5,5,32,64])
b_conv2 = bias_variable([64])

h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

#third layer,fclayer
#now after two pooling, the size of image is 7*7

w_fc1 = weight_variable([7*7*64,1024])
b_fc1 = bias_variable([1024])

#mat2vec
h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)

#add drop to fc layer
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

w_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])

y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, w_fc2) + b_fc2)

#cost func
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
sess.run(tf.global_variables_initializer())

for i in range(20000):
    batch = mnist.train.next_batch(50)
    if i%100 == 0:
        train_accuracy = accuracy.eval(feed_dict = {x:batch[0],y_:batch[1],keep_prob:1.0})
        print ("step %d,training accuracy %g"%(i,train_accuracy))
    train_step.run(feed_dict={x:batch[0],y_:batch[1],keep_prob:0.5})

print ("test accuracy %g"%accuracy.eval(feed_dict = {x:mnist.test.images,y_:mnist.test.labels,keep_prob:1.0}))

这是tensorflow官方提供的mnist数据下载文件

# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import tempfile

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

值得一提的是,drop_out技巧,用在全连接层,随机的减去全连接层当中的一些连接,加强网络范化能力,防止过拟合。
在我的电脑上跑了快一个小时,最终的范化误差再0.993。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值