Tensorflow_dropout解决过拟合问题

原创 2018年04月17日 10:01:59
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

#载入数据集;对数据集分batch并计算总共有多少batch
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)
batch_size = 100
n_batch = mnist.train.num_examples // batch_size

#定义两个placeholder, keep_prob控制有多少个网络节点用来训练模型
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
keep_prob=tf.placeholder(tf.float32)

#创建一个简单的神经网络
W1 = tf.Variable(tf.truncated_normal([784,2000],stddev=0.1))
b1 = tf.Variable(tf.zeros([2000])+0.1)
L1 = tf.nn.tanh(tf.matmul(x,W1)+b1)
L1_drop = tf.nn.dropout(L1,keep_prob) 

W2 = tf.Variable(tf.truncated_normal([2000,2000],stddev=0.1))
b2 = tf.Variable(tf.zeros([2000])+0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop,W2)+b2)
L2_drop = tf.nn.dropout(L2,keep_prob) 

W3 = tf.Variable(tf.truncated_normal([2000,1000],stddev=0.1))
b3 = tf.Variable(tf.zeros([1000])+0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop,W3)+b3)
L3_drop = tf.nn.dropout(L3,keep_prob) 

W4 = tf.Variable(tf.truncated_normal([1000,10],stddev=0.1))
b4 = tf.Variable(tf.zeros([10])+0.1)
prediction = tf.nn.softmax(tf.matmul(L3_drop,W4)+b4)

#交叉熵代价函数并用梯度下降法进行训练
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

#初始化变量
init = tf.global_variables_initializer()

#结果存放在一个布尔型列表中,并计算准确率 argmax()返回一维向量中最大值所在的位置
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(31):
        for batch in range(n_batch):
            batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys,keep_prob:0.7}) #0.7的网络节点训练
        #keep_prob:1.0 表示所有的网络节点用来测试
        test_acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
        train_acc = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels,keep_prob:1.0})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) +",Training Accuracy " + str(train_acc))

正常情况下,用训练集来测试和用测试集来测试的差异并不是太大,但是这个试验中差异却很大,这就是过拟合导致的。

过拟合一般出现在这样的情况下:待训练的网络结构复杂,这就会使参数太多,但是训练数据不足,从而出现过拟合。

可以通过增加数据集的方法,或者dropout方法解决。

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/lansetiankong2104/article/details/79970537

过拟合问题及解决办法

1 什么是过拟合? 一般提及到过拟合就是说在训练集上模型表现很好,但是在测试集上效果很差,即模型的泛化能力不行。过拟合是模型训练过程中参数拟合的问题,由于训练数据本身有采样误差,拟合模型参数时这些采样...
  • liubo187
  • liubo187
  • 2017-08-11 10:52:35
  • 389

过拟合问题以及解决方法

一.所谓过拟合问题 为了得到一致假设而使假设变得过度复杂称为过拟合。一个过配的模型试图连误差(噪音)都去解释(而实际上噪音又是不需要解释的),导致繁华能力较差,显然过犹不及了。 A model (...
  • ff19910203
  • ff19910203
  • 2015-09-20 21:28:50
  • 1979

过拟合及其解决方法

过拟合是在训练数据上拟合效果好,但在测试数据上效果比较差。我们通过偏频派与贝叶斯派分别解析为什么加入正则化项能避免过拟合现象。...
  • lin360580306
  • lin360580306
  • 2016-04-24 15:08:48
  • 11193

机器学习中的过拟合问题以及解决方案

笔者希望该笔记能够记录每个机器学习算法的过拟合问题。 过拟合问题举例 右图在训练数据上拟合完美,但是预测第11个时候, 左图虽然拟合不完全,但是更合理;右图的-953,误差极大。 ...
  • sinat_26917383
  • sinat_26917383
  • 2016-06-08 20:24:18
  • 7323

过拟合问题,通常会考虑两种途径来解决:a) 减少特征的数量:b) 正则化.

http://52opencourse.com/133/coursera%E5%85%AC%E5%BC%80%E8%AF%BE%E7%AC%94%E8%AE%B0-%E6%96%AF%E5%9D%A6...
  • hzw05103020
  • hzw05103020
  • 2015-11-18 10:50:48
  • 950

机器学习中过拟合问题分析及解决方法

机器学习中过拟合问题分析及解决方法表现:在训练集上的误差特别小,在测试集上的误差特别大。 原因:模型过于复杂,过分拟合数据噪声和outliers(离群值). 解决方法: 1、正则化。模型中添加先...
  • qq_23617681
  • qq_23617681
  • 2016-05-20 22:31:00
  • 824

机器学习中防止过拟合的处理方法

在进行数据挖掘或者机器学习模型建立的时候,因为在统计学习中,假设数据满足独立同分布,即当前已产生的数据可以对未来的数据进行推测与模拟,因此都是使用历史数据建立模型,即使用已经产生的数据去训练,然后使用...
  • heyongluoyao8
  • heyongluoyao8
  • 2015-10-26 20:58:12
  • 84241

过拟合问题简述

关于过度拟合的概念:给定一个假设空间H,一个假设h∈H,如果存在其他的假设h’∈H,使得在训练样例上h的错误率比h‘小,但在整个实例分布上h’的错误率比h小,那么就说假设h过度拟合训练数据 过滤拟合可...
  • mm_bit
  • mm_bit
  • 2015-07-22 09:51:12
  • 485

过拟合的原因+处理方法

过拟合的原因 1. 我们得到的模型g 太复杂。f很小,g 太大,会过拟合 2. 原本的模型(目标函数) f 太复杂 。g达不到f的形式,也会产生过拟合。模型f太复杂,其实也是一种噪声。 3...
  • MosBest
  • MosBest
  • 2016-08-11 23:26:45
  • 7092

处理过拟合问题-Regularization

数学中的Regularization是为了解决overfitting问题而引入的一种方法。所谓overfitting就是在一些数学模型中由于过于复杂,有太多的观测参数,以至于一点点微小的误差都回产生巨...
  • Real_Myth
  • Real_Myth
  • 2017-02-04 13:38:14
  • 1512
收藏助手
不良信息举报
您举报文章:Tensorflow_dropout解决过拟合问题
举报原因:
原因补充:

(最多只允许输入30个字)