LeNet:最早用于数字识别的CNN
输入层:32*32===》C1===》S2(平均池化)===》C3===》S4===》C5===》F6===》OUTPUT(输出层)
TensorFlow实现
数据集准备:使用MNIST手写数字数据集(28*28灰度图),训练脚本中通过input_data.read_data_sets读取。
首先是网络结构如下:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/10/22 14:15
# @Author : HJH
# @File : LeNet.py
# @Software: PyCharm
"""
# code is far away from bugs with the god animal protecting
I love animals. They taste delicious.
┏┓ ┏┓
┏┛┻━━━┛┻┓
┃ ☃ ┃
┃ ┳┛ ┗┳ ┃
┃ ┻ ┃
┗━┓ ┏━┛
┃ ┗━━━┓
┃ 神兽保佑 ┣┓
┃ 永无BUG! ┏┛
┗┓┓┏━┳┓┏┛
┃┫┫ ┃┫┫
┗┻┛ ┗┻┛
"""
import tensorflow as tf
class LeNet(object):
    """LeNet-style convolutional network built as a TensorFlow 1.x graph.

    Constructing an instance immediately builds the graph. The output of the
    last fully connected layer (no ReLU applied) is exposed as ``self.fc5``.
    """

    def __init__(self, x, num_class, keep_drop, regulation_rate):
        """Store the hyper-parameters and build the network.

        Args:
            x: Input tensor. Its last dimension is read as the channel count
                and its spatial dimensions must be statically known, because
                the flattened feature length is computed at graph-build time.
            num_class: Number of units of the final layer.
            keep_drop: Dropout keep probability. Stored, but not used by the
                architecture currently built in ``__create__``.
            regulation_rate: L2 regularisation rate handed to every ``fc`` layer.
        """
        self.x = x
        self.num_class = num_class
        self.keep_drop = keep_drop
        self.regulation_rate = regulation_rate
        self.__create__()

    def __create__(self):
        """Build all layers and set ``self.fc5`` to the final layer output."""
        # Variant used by most blog posts (kept for reference):
        # conv1 = conv(self.x, 5, 5, 32, 1, 1, name='conv1', padding='SAME')
        # pool1 = pool(conv1, 2, 2, 2, 2, name='pool1', padding='SAME')
        #
        # conv2 = conv(pool1, 5, 5, 64, 1, 1, name='conv2', padding='SAME')
        # pool2 = pool(conv2, 2, 2, 2, 2, name='pool2', padding='SAME')
        #
        # pool2_shape = pool2.get_shape().as_list()
        # pool2_len = pool2_shape[1] * pool2_shape[2] * pool2_shape[3]
        # reshaped = tf.reshape(pool2, [pool2_shape[0], pool2_len])
        # fc3 = fc(reshaped, pool2_len, 512, self.regulation_rate, name='fc3')
        # drop3 = dropout(fc3, self.keep_drop)
        #
        # self.fc4 = fc(drop3, 512, self.num_class, self.regulation_rate, name='fc4', relu=False)

        # Variant using the layer widths from the paper (6, 16, 120, 84).
        # Every layer uses SAME padding with stride-1 convolutions, so only
        # the two stride-2 pooling layers reduce the spatial size.
        conv1 = conv(self.x, 5, 5, 6, 1, 1, name='conv1', padding='SAME')
        pool1 = pool(conv1, 2, 2, 2, 2, name='pool1', padding='SAME')
        conv2 = conv(pool1, 5, 5, 16, 1, 1, name='conv2', padding='SAME')
        pool2 = pool(conv2, 2, 2, 2, 2, name='pool2', padding='SAME')
        conv3 = conv(pool2, 5, 5, 120, 1, 1, name='conv3', padding='SAME')

        # Flatten: the feature length is static, the batch dimension is left
        # as -1 so the graph also builds when the batch size is unknown (None).
        conv3_shape = conv3.get_shape().as_list()
        conv3_len = conv3_shape[1] * conv3_shape[2] * conv3_shape[3]
        reshaped = tf.reshape(conv3, [-1, conv3_len])

        # NOTE: the variable scopes are named 'fc3'/'fc4' although they hold
        # the 4th and 5th layers; the names are kept so that variable names in
        # existing checkpoints still match.
        fc4 = fc(reshaped, conv3_len, 84, self.regulation_rate, name='fc3')
        self.fc5 = fc(fc4, 84, self.num_class, self.regulation_rate, name='fc4', relu=False)
def conv(x, filter_height, filter_width, num_filter, stride_y, stride_x, name, padding='SAME'):
    """Create a 2-D convolution layer followed by a bias add and ReLU.

    Args:
        x: Input tensor; the size of its last dimension is used as the number
            of input channels and must be statically known.
        filter_height: Kernel height.
        filter_width: Kernel width.
        num_filter: Number of output channels.
        stride_y: Stride along the second dimension of ``x``.
        stride_x: Stride along the third dimension of ``x``.
        name: Variable scope that holds the layer's ``weights`` and ``biases``.
        padding: Padding mode passed to ``tf.nn.conv2d``.

    Returns:
        The ReLU activation tensor, named after the variable scope.
    """
    input_channels = int(x.get_shape()[-1])
    with tf.variable_scope(name) as scope:
        weights = tf.get_variable("weights", shape=[filter_height, filter_width, input_channels, num_filter],
                                  initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable("biases", shape=[num_filter], initializer=tf.constant_initializer(0.0))
        feature_map = tf.nn.conv2d(x, weights, padding=padding, strides=[1, stride_y, stride_x, 1])
        # bias_add already returns a tensor shaped like its input, so no
        # extra reshape is needed before the activation.
        pre_activation = tf.nn.bias_add(feature_map, biases)
        return tf.nn.relu(pre_activation, name=scope.name)
def pool(x, filter_height, filter_width, stride_y, stride_x, name, padding='SAME'):
    """Apply ``tf.nn.max_pool`` to ``x``.

    The pooling window is ``filter_height`` x ``filter_width`` and the strides
    are ``stride_y`` / ``stride_x``, both applied to the two middle dimensions
    of ``x``. ``name`` and ``padding`` are forwarded unchanged.
    """
    window = [1, filter_height, filter_width, 1]
    step = [1, stride_y, stride_x, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step, padding=padding, name=name)
def fc(x, num_in, num_out, regulation_rate, name, relu=True):
    """Create a fully connected layer with L2-regularised weights.

    The L2 penalty of the weight matrix is appended to the ``"losses"`` graph
    collection so the caller can add it to the training loss.

    Args:
        x: 2-D input tensor with ``num_in`` columns.
        num_in: Number of input features.
        num_out: Number of output units.
        regulation_rate: Scale passed to ``tf.contrib.layers.l2_regularizer``.
        name: Variable scope that holds the layer's ``weights`` and ``biases``.
        relu: If truthy, apply ReLU to the affine output; otherwise return the
            affine output itself.

    Returns:
        The layer output tensor, named after the variable scope.
    """
    with tf.variable_scope(name) as scope:
        weights = tf.get_variable("weights", shape=[num_in, num_out],
                                  initializer=tf.truncated_normal_initializer(stddev=0.1))
        biases = tf.get_variable("biases", shape=[num_out], initializer=tf.constant_initializer(0.1))
        tf.add_to_collection("losses", tf.contrib.layers.l2_regularizer(regulation_rate)(weights))
        # Test truthiness rather than identity with True, and do not rebind
        # the `relu` flag to a tensor.
        if relu:
            act = tf.nn.xw_plus_b(x, weights, biases)
            return tf.nn.relu(act, name=scope.name)
        return tf.nn.xw_plus_b(x, weights, biases, name=scope.name)
def dropout(x, keep_drop):
    """Return ``x`` passed through ``tf.nn.dropout`` with keep probability ``keep_drop``."""
    dropped = tf.nn.dropout(x, keep_prob=keep_drop)
    return dropped
训练过程如下:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2018/10/24 11:04
# @Author : HJH
# @File : train.py
# @Software: PyCharm
"""
# code is far away from bugs with the god animal protecting
I love animals. They taste delicious.
┏┓ ┏┓
┏┛┻━━━┛┻┓
┃ ☃ ┃
┃ ┳┛ ┗┳ ┃
┃ ┻ ┃
┗━┓ ┏━┛
┃ ┗━━━┓
┃ 神兽保佑 ┣┓
┃ 永无BUG! ┏┛
┗┓┓┏━┳┓┏┛
┃┫┫ ┃┫┫
┗┻┛ ┗┻┛
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import LeNet
import numpy as np
import os
# Number of examples fed per training step.
batch_size = 100
# Some code found online stays stuck at around 10% accuracy because its
# learning rate is set too high (0.8, if memory serves).
learning_rate_base = 0.01
learning_rate_decay = 0.99
# Number of training steps (one mini-batch each) run by train(), despite the name.
epochs = 30000
moving_average_decay = 0.99
regulation_rate = 0.0001
num_classes = 10
# Value fed to the keep_prob placeholder during training.
droup_rate = 0.5
# Checkpoints are written under checkpoints/ and TensorBoard event files under
# tensorboard/ (the two directories were previously swapped).
model_save_path = './model/checkpoints/'
model_name = 'lenet.ckpt'
summary_save_path = './model/tensorboard/'
def train(data):
    """Build the LeNet training graph and run the training loop.

    Each iteration feeds one mini-batch, prints the loss and batch accuracy,
    and every 1000 iterations writes a TensorBoard summary and a checkpoint.

    Args:
        data: Dataset object exposing ``data.train.num_examples`` and
            ``data.train.next_batch(batch_size)``; labels must be one-hot with
            ``num_classes`` columns and each image must reshape to 28x28x1.
    """
    # Dataset placeholders.
    x = tf.placeholder(tf.float32, shape=[batch_size, 28, 28, 1])
    y = tf.placeholder(tf.float32, shape=[None, num_classes])
    keep_prob = tf.placeholder(tf.float32)

    # Network model.
    model = LeNet.LeNet(x, num_classes, keep_prob, regulation_rate)
    global_step = tf.Variable(0, trainable=False)
    score = model.fc5

    # Loss: mean cross entropy plus the L2 terms collected under 'losses'.
    variable_average = tf.train.ExponentialMovingAverage(moving_average_decay, global_step)
    variable_average_op = variable_average.apply(tf.trainable_variables())
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=score, labels=y)
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    with tf.name_scope('loss'):
        loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    # Learning rate, decayed once per pass over the training set.
    with tf.name_scope('learning_rate'):
        learning_rate = tf.train.exponential_decay(learning_rate_base, global_step,
                                                   data.train.num_examples / batch_size, learning_rate_decay)

    # Gradient descent; train_op runs the update and the moving averages together.
    with tf.name_scope('train'):
        train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
        with tf.control_dependencies([train_step, variable_average_op]):
            train_op = tf.no_op(name='train')

    # Accuracy on the current training batch.
    with tf.name_scope('train_accuracy'):
        correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # TensorBoard summaries.
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('learning_rate', learning_rate)
    tf.summary.scalar('accuracy', accuracy)
    merged_summary = tf.summary.merge_all()

    saver = tf.train.Saver()
    # Make sure the checkpoint directory exists before the first save.
    if not os.path.isdir(model_save_path):
        os.makedirs(model_save_path)

    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        writer = tf.summary.FileWriter(summary_save_path, sess.graph)
        try:
            for i in range(epochs):
                xs, ys = data.train.next_batch(batch_size)
                xs = np.reshape(xs, [batch_size, 28, 28, 1])
                _, loss_value, step, accuracy_value = sess.run([train_op, loss, global_step, accuracy],
                                                               feed_dict={x: xs, y: ys, keep_prob: droup_rate})
                print("After %d training steps,loss is %g,accuracy is %g" % (step, loss_value, accuracy_value))
                if i % 1000 == 0:
                    summary_value = sess.run(merged_summary, feed_dict={x: xs, y: ys, keep_prob: 1.})
                    writer.add_summary(summary_value, step)
                    saver.save(sess, os.path.join(model_save_path, model_name), global_step)
        finally:
            # Flush pending events and release the file handle even if
            # training is interrupted.
            writer.close()
if __name__ == '__main__':
    # Load MNIST with one-hot labels from ./MNIST/ and start training.
    train(input_data.read_data_sets('./MNIST/', one_hot=True))
训练结果如下所示: