四、优化方法

目录

4-1 交叉熵

4-2 Dropout:解决神经网络的过拟合问题

4-3 优化器的选择

4-4 优化MNIST手写数据集分类


 

4-1 交叉熵

     关于交叉熵的理解,可以参考这些文章:

      https://blog.csdn.net/chaipp0607/article/details/73392175

      https://blog.csdn.net/m_buddy/article/details/80224409

     https://blog.csdn.net/red_stone1/article/details/80735068

     在最理想的情况下,如果一个样本属于k,那么这个类别所对应的的输出节点的输出值应该为1,而其他节点的输出都为0,即[0,0,1,0,….0,0],这个数组也就是样本的Label,是神经网络最期望的输出结果,交叉熵就是用来判定实际的输出与期望的输出的接近程度!

在TensorFlow中实现softmax交叉熵:

loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=prediction))

4-2 Dropout:解决神经网络的过拟合问题

    Dropout的工作原理:是在神经网络的dropout层为每个神经元结点设置一个随机消除概率,对于保留下来的神经元,我们得到一个节点较少,规模较小的网络进行训练(通过程序理解就是调节keep_prob的大小来控制工作神经元的个数)。这也是一种正则化方法。

4-3 优化器的选择

    Tensorflow有很多优化器,针对不同的情况选择合适的优化器

  • GradientDescentOptimizer
  • AdagradOptimizer
  • AdagradDAOptimizer
  • MomentumOptimizer
  • AdamOptimizer
  • FtrlOptimizer
  • RMSPropOptimizer

4-4 优化MNIST手写数据集分类

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data


#导入数据集
mnist = input_data.read_data_sets("MNIST_data",one_hot=True)#如果本地没有数据集,此语句会自动下载到对应的文件夹位置,不过网速较慢,不建议
#每个批次的大小
batch_size = 100
#计算一共需要多少个批次
n_batch = mnist.train.num_examples // batch_size
#创建两个placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
keep_prob = tf.placeholder(tf.float32)
Ir = tf.Variable(0.01,dtype=tf.float32)

#创建一个简单的神经网络
W1 = tf.Variable(tf.truncated_normal([784,500],stddev=0.1))#如果优化程序可以降低神经元的个数(2000),本程序设置的比较大
b1 = tf.Variable(tf.zeros([500])+0.1)
L1 = tf.nn.tanh(tf.matmul(x,W1) + b1)
L1_drop = tf.nn.dropout(L1,keep_prob)

W2 = tf.Variable(tf.truncated_normal([500,300],stddev=0.1))
b2 = tf.Variable(tf.zeros([300])+0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop,W2) + b2)
L2_drop = tf.nn.dropout(L2,keep_prob)

W3 = tf.Variable(tf.truncated_normal([300,10],stddev=0.1))
b3 = tf.Variable(tf.zeros([10])+0.1)
prediction = tf.nn.softmax(tf.matmul(L2_drop,W3) + b3)

#交叉熵损失函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=ls=y,logits=prediction))

#使用AdamOptimizer训练
train_step = tf.train.AdamOptimizer(Ir).minimize(loss)
#结果存放在一个布尔类型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#对比预测结果的标签是否一致,一致为True,不同为False
#预测准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))#将布尔型转化为0.0-1.0之间的数值,True为1.0,False为0.0
#变量初始化
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(50):#可以修改迭代次数多次测试找到最佳结果
        learning_rate = sess.run(tf.assign(Ir,0.01 * (0.95 ** epoch)))#随着迭代次数的增加学习率降低
        for batch in range(n_batch):
            batch_x,batch_y = mnist.train.next_batch(batch_size)#
            sess.run(train_step,feed_dict={x:batch_x,y:batch_y,keep_prob:1})#调整keep_prob的大小来选择工作的神经元个数
            
        test_acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels,keep_prob:1.0})
        print('Iter' + str(epoch) + ',Testing Accuaracy' + str(test_acc) + ',Lenrning_rate' + str(learning_rate))

训练结果:

Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
Iter0,Testing Accuaracy0.923,Lenrning_rate0.01
Iter1,Testing Accuaracy0.9139,Lenrning_rate0.0095
Iter2,Testing Accuaracy0.9059,Lenrning_rate0.009025
Iter3,Testing Accuaracy0.9213,Lenrning_rate0.00857375
Iter4,Testing Accuaracy0.9241,Lenrning_rate0.00814506
Iter5,Testing Accuaracy0.9251,Lenrning_rate0.00773781
Iter6,Testing Accuaracy0.935,Lenrning_rate0.00735092
Iter7,Testing Accuaracy0.9195,Lenrning_rate0.00698337
Iter8,Testing Accuaracy0.9337,Lenrning_rate0.0066342
Iter9,Testing Accuaracy0.9402,Lenrning_rate0.00630249
Iter10,Testing Accuaracy0.9376,Lenrning_rate0.00598737
Iter11,Testing Accuaracy0.9427,Lenrning_rate0.005688
Iter12,Testing Accuaracy0.9413,Lenrning_rate0.0054036
Iter13,Testing Accuaracy0.9392,Lenrning_rate0.00513342
Iter14,Testing Accuaracy0.9397,Lenrning_rate0.00487675
Iter15,Testing Accuaracy0.9462,Lenrning_rate0.00463291
Iter16,Testing Accuaracy0.9442,Lenrning_rate0.00440127
Iter17,Testing Accuaracy0.9428,Lenrning_rate0.0041812
Iter18,Testing Accuaracy0.9497,Lenrning_rate0.00397214
Iter19,Testing Accuaracy0.9537,Lenrning_rate0.00377354
Iter20,Testing Accuaracy0.9519,Lenrning_rate0.00358486
Iter21,Testing Accuaracy0.9495,Lenrning_rate0.00340562
Iter22,Testing Accuaracy0.9545,Lenrning_rate0.00323534
Iter23,Testing Accuaracy0.9527,Lenrning_rate0.00307357
Iter24,Testing Accuaracy0.9555,Lenrning_rate0.00291989
Iter25,Testing Accuaracy0.9541,Lenrning_rate0.0027739
Iter26,Testing Accuaracy0.9546,Lenrning_rate0.0026352
Iter27,Testing Accuaracy0.9551,Lenrning_rate0.00250344
Iter28,Testing Accuaracy0.958,Lenrning_rate0.00237827
Iter29,Testing Accuaracy0.9599,Lenrning_rate0.00225936
Iter30,Testing Accuaracy0.9552,Lenrning_rate0.00214639
Iter31,Testing Accuaracy0.9596,Lenrning_rate0.00203907
Iter32,Testing Accuaracy0.9596,Lenrning_rate0.00193711
Iter33,Testing Accuaracy0.9597,Lenrning_rate0.00184026
Iter34,Testing Accuaracy0.9611,Lenrning_rate0.00174825
Iter35,Testing Accuaracy0.9609,Lenrning_rate0.00166083
Iter36,Testing Accuaracy0.9606,Lenrning_rate0.00157779
Iter37,Testing Accuaracy0.9641,Lenrning_rate0.0014989
Iter38,Testing Accuaracy0.9617,Lenrning_rate0.00142396
Iter39,Testing Accuaracy0.9623,Lenrning_rate0.00135276
Iter40,Testing Accuaracy0.9626,Lenrning_rate0.00128512
Iter41,Testing Accuaracy0.9631,Lenrning_rate0.00122087
Iter42,Testing Accuaracy0.9643,Lenrning_rate0.00115982
Iter43,Testing Accuaracy0.963,Lenrning_rate0.00110183
Iter44,Testing Accuaracy0.9655,Lenrning_rate0.00104674
Iter45,Testing Accuaracy0.966,Lenrning_rate0.000994403
Iter46,Testing Accuaracy0.9656,Lenrning_rate0.000944682
Iter47,Testing Accuaracy0.9658,Lenrning_rate0.000897448
Iter48,Testing Accuaracy0.9658,Lenrning_rate0.000852576
Iter49,Testing Accuaracy0.9665,Lenrning_rate0.000809947

 

说明:1.关于tensorflow的代码是参考了b站练数成金的代码,链接地址:https://www.bilibili.com/video/av20542427/?p=1

           2.部分代码还参考了tensorflow中文社区的网站,以及tensorflow的官网(需要梯子)。

           tensorflow中文社区的网站:http://www.tensorfly.cn/

   MNIST_data手写数据集下载:链接:https://pan.baidu.com/s/1_PxLxxZ4YP7KfDzZh8vPFA 密码:nyrs

   更多Tensorflow资源下载,去github搜索Tensorflow下载更多demo

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值