下面实现了一个四层全连接网络,利用mnist数据集对这个网络进行训练,代码中keep_prob参数设置为[1.,1.,1.],说明并未进行dropout操作,网络结构如下
1 layer 2 layer 3 layer 4 layer
___w 784 x 1024 ____ w 1024 x 512 ___ w 512 x 256 ___ 3 256 x 10
| | | |
input_x --| ----| ---| ---|
|___b 1024 |____ b 512 |___ b 256 |___ b 10
代码如下
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# 在tensorflow的log日志等级如下:
# - 0:显示所有日志(默认等级)
# - 1:显示info、warning和error日志
# - 2:显示warning和error信息
# - 3:显示error日志信息
os.environ["TF_CPP_MIN_LOG_LEVEL"] = '2'
# MNIST_data是个.zip压缩文件 保存在跟py文件同样路径下
mnist_set = input_data.read_data_sets('./MNIST_data', one_hot=True)
batch_size = 32 # batch_size 批大小,根据自己的gup内存大小设置
batch_num = mnist_set.train.num_examples // batch_size # 每批样本数量 // 在python语法中表示整除 向下取整
# 定义两个占位符
x = tf.placeholder(tf.float32,[None,784])
y_data = tf.placeholder(tf.float32,[None,10])
keep_prob = tf.placeholder(tf.float32,[3])
# build a network
# layer 1 hidden : 1024
weight1 = tf.Variable(initial_value=tf.truncated_normal([784,1024], stddev=0.1))
bias1 = tf.Variable(initial_value=tf.zeros([1024]))
l1 = tf.nn.tanh(tf.matmul(x,weight1) + bias1)
l1 = tf.nn.dropout(l1,keep_prob=keep_prob[0])
# layer 2 hidden : 512
weight2 = tf.Variable(initial_value=tf.truncated_normal([1024,512], stddev=0.1))
bias2 = tf.Variable(initial_value=tf.zeros([512]))
l2 = tf.nn.tanh(tf.matmul(l1,weight2) + bias2)
l2 = tf.nn.dropout(l2,keep_prob=keep_prob[1])
# layer 3 hidden : 256
weight3 = tf.Variable(initial_value=tf.truncated_normal([512,256], stddev=0.1))
bias3 = tf.Variable(initial_value=tf.zeros([256]))
l3 = tf.nn.tanh(tf.matmul(l2,weight3) + bias3)
l3 = tf.nn.dropout(l3,keep_prob=keep_prob[2])
# output layer : 10
weight4 = tf.Variable(initial_value=tf.truncated_normal([256,10], stddev=0.1))
bias4 = tf.Variable(initial_value=tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(l3,weight4) + bias4)
# train stage
# loss
# loss = tf.reduce_mean(tf.square(y_data - prediction))
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_data, logits=prediction))
# optimize
# train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
# test stage
# tf.argmax 返回一维中最大值的位置
# 比如
# a = [[2,3,2,0]
# [3,5,2,9]
# [2,7,6,2]]
# tf.argmax(a,0) 0 表示按列处理
# 结果[1,2,2,1]
# tf.argmax(a,1) 1 表示按行处理
# 结果[1,3,1]
correct_prediction = tf.equal(tf.argmax(prediction,1),tf.argmax(y_data,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for epoch in range(1,21):
for batch in range(batch_num):
batch_x, batch_y = mnist_set.train.next_batch(batch_size)
sess.run(train_op,feed_dict={x:batch_x,y_data:batch_y,keep_prob:[1.,1.,1.]})
test_acc = sess.run(accuracy,feed_dict=
{x:mnist_set.test.images,y_data:mnist_set.test.labels,keep_prob:[1.,1.,1.]})
train_acc = sess.run(accuracy,feed_dict=
{x:mnist_set.train.images,y_data:mnist_set.train.labels,keep_prob:[1.,1.,1.]})
print('第%d代:测试正确率为:%s, 训练正确率率:%s' % (epoch, test_acc, train_acc))
结果如下
Extracting ./MNIST_data\train-images-idx3-ubyte.gz
Extracting ./MNIST_data\train-labels-idx1-ubyte.gz
Extracting ./MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data\t10k-labels-idx1-ubyte.gz
第1代:测试正确率为:0.9442, 训练正确率率:0.9482727
第2代:测试正确率为:0.9529, 训练正确率率:0.9585818
第3代:测试正确率为:0.9499, 训练正确率率:0.95518184
第4代:测试正确率为:0.9581, 训练正确率率:0.9646
第5代:测试正确率为:0.9601, 训练正确率率:0.96958184
第6代:测试正确率为:0.9585, 训练正确率率:0.9652
第7代:测试正确率为:0.9641, 训练正确率率:0.9732909
第8代:测试正确率为:0.9618, 训练正确率率:0.97163635
第9代:测试正确率为:0.9646, 训练正确率率:0.9757636
第10代:测试正确率为:0.9642, 训练正确率率:0.9754909
第11代:测试正确率为:0.9658, 训练正确率率:0.9774
第12代:测试正确率为:0.9657, 训练正确率率:0.97810906
第13代:测试正确率为:0.9659, 训练正确率率:0.9764909
第14代:测试正确率为:0.9699, 训练正确率率:0.9815636
第15代:测试正确率为:0.9532, 训练正确率率:0.9654
第16代:测试正确率为:0.9682, 训练正确率率:0.9774727
第17代:测试正确率为:0.9705, 训练正确率率:0.98105454
第18代:测试正确率为:0.9635, 训练正确率率:0.97763634
第19代:测试正确率为:0.9657, 训练正确率率:0.98025453
第20代:测试正确率为:0.9694, 训练正确率率:0.98094547
从以上结果中我们可以看到测试集正确率明显比训练集正确率低0.01百分点,由于我们的网络结构比较简单,如果利用像vgg-16进行比较的话得到的结果偏差会更大。这说明网络模型出现了过拟合现象。大家都知道抑制过拟合的方法主要有一下三种:
1 数据增强data augmentation
2 loss上加正则化项
3 dropout
现在我就利用dropout这个方法来抑制过拟合,只需把keep_prob参数修改为[.6,.6,.5],注意,在计算测试集和训练集准确率时,keep_prob的参数要都设置为[1.,1.,1.]。设置后重新训练,结果如下
Extracting ./MNIST_data\train-images-idx3-ubyte.gz
Extracting ./MNIST_data\train-labels-idx1-ubyte.gz
Extracting ./MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ./MNIST_data\t10k-labels-idx1-ubyte.gz
第1代:测试正确率为:0.9236, 训练正确率:0.9226
第2代:测试正确率为:0.9374, 训练正确率:0.9390727
第3代:测试正确率为:0.9438, 训练正确率:0.9433636
第4代:测试正确率为:0.9502, 训练正确率:0.94914544
第5代:测试正确率为:0.9468, 训练正确率:0.94985455
第6代:测试正确率为:0.951, 训练正确率:0.9518727
第7代:测试正确率为:0.9513, 训练正确率:0.95469093
第8代:测试正确率为:0.9528, 训练正确率:0.95454544
第9代:测试正确率为:0.9503, 训练正确率:0.9557818
第10代:测试正确率为:0.9521, 训练正确率:0.95485455
第11代:测试正确率为:0.9564, 训练正确率:0.9591454
第12代:测试正确率为:0.9535, 训练正确率:0.95845455
第13代:测试正确率为:0.9547, 训练正确率:0.9584
第14代:测试正确率为:0.9562, 训练正确率:0.9604545
第15代:测试正确率为:0.9558, 训练正确率:0.95878184
第16代:测试正确率为:0.9552, 训练正确率:0.95896363
第17代:测试正确率为:0.9593, 训练正确率:0.9622
第18代:测试正确率为:0.9568, 训练正确率:0.9621636
第19代:测试正确率为:0.9589, 训练正确率:0.9616909
第20代:测试正确率为:0.9584, 训练正确率:0.9609454
从结果可以看出,测试集和训练集准确率之差较之前的有所减小,说明dropout对抑制过拟合具有一定作用