以MNIST为例讲解如何代码实现模块化搭建神经网络"八股"

 

一  前向传播部分之forward.py

def forward(x, regularizer):

    w=

    b=

    y=

    return y

def get_weights(shape,regularizer):

def get_bias(shape):

二 反向传播部分之backward.py

def backward(mnist):

    x =

    y_=

    y  =

    global_step=

    loss =

    <正则化, 指数衰减学习率 ,滑动平均>

    train_step =

    实例化saver

     with tf.Session() as sess:

    初始化

     for i in range(STEPS):

         sess.run(train_step, feed_dict={x: , y_:})

         if i % 轮数 == 0:
             print

             saver.save( )

三 其他算法的加入

(1) 如果损失函数loss 含正则化regularization

则需要在backward.py中加入 

  ce = if.nn.sparse_softmax_cross_entropy_with_logits( = logits = y, labels = tf.argmax(y_, 1))

  cem = tf.reduce_mean(ce)

  loss = cem + tf.add_n (tf.get_collection( 'losses'))

在forward.py中加入

   if regularizer != None: tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))

(2) 加入指数衰减学习率 lerning_rate

需要在backward.py中加入

learning_rate = tf.train.exponential_decay(

LEARNING_RATE_BASE,

global_step,

LEARNING_RATE_STEP,

LEARNING_RATE_DECAY,

staircase =True)

(3) 加入滑动平均 ema

需要在backward.py中加入

ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)

ema_op = ema.apply(tf.trainable_variables())

with tf.control_dependencies([train_step, ema_op]):

        train_op = tf.no_op(name = 'train')

四 模型测试之test.py

def test(mnist)

    with tf.Graph() as g:

    定义 x    y_   y

     实例化可还原滑动平均值的saver

     计算正确率

     while True:

         with tf.Session() as sess:

              加载ckpt模型 ckpt = tf.train.get_checkpoint_state(存储路径)

              如果已有ckpt模型则恢复 if ckpt and ckpt.model_checkpoint_path:

              恢复会话 caver.restore(sess,, ckpt.model_checkpoint_path)

              恢复轮数global_step = ckpt.model_checkpoint_path.split( '/')[-1].split('-')[-1]

               计算准确率 accuracy_score = sess.run(accuracy, feed_dict =

                                      {x: mnist.test.images, y_: mnist.test.lavels})

               打印提示 print (" After %s training step(s) , test accuracy  = %g" % (global_step, accuracy_score))

             如果没有模型 else:
           给出提示 print ('NO checkpoint file found')

            return

def main():

    mnist = input_data.read_data_sets("./data/", one_hot= True)

    test(mnist)

if __name__ == '__main__':

    main()

 

 

 

 

 

 

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值