Batch normalization (BN) computation in TensorFlow

How BN works: https://blog.csdn.net/sunjinshengli/article/details/74037208
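For reference, the two modes of BN can be sketched in NumPy. This is a minimal sketch, assuming the `tf.layers.batch_normalization` defaults `momentum=0.99` and `epsilon=1e-3`; the function names `bn_train_step` and `bn_inference` are hypothetical, not TensorFlow APIs. In training mode the layer normalizes with the batch statistics and updates the moving averages; in inference mode (`training=False`, the case studied in this post) it uses only the moving averages. TensorFlow's exact handling of the variance in the moving-average update may differ from this sketch.

```python
import numpy as np


def bn_train_step(x, gamma, beta, moving_mean, moving_var, momentum=0.99, eps=1e-3):
    """Training-mode BN (hypothetical helper): batch statistics + moving-average update."""
    batch_mean = x.mean(axis=0)
    batch_var = x.var(axis=0)  # population variance over the batch axis
    y = gamma * (x - batch_mean) / np.sqrt(batch_var + eps) + beta
    # Exponential moving averages, which are what inference mode uses later
    new_mean = momentum * moving_mean + (1.0 - momentum) * batch_mean
    new_var = momentum * moving_var + (1.0 - momentum) * batch_var
    return y, new_mean, new_var


def bn_inference(x, gamma, beta, moving_mean, moving_var, eps=1e-3):
    """Inference-mode BN (hypothetical helper): a fixed per-feature affine transform."""
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta


rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(50, 8))  # a batch of 50 samples, 8 features
gamma, beta = np.ones(8), np.zeros(8)
mean0, var0 = np.zeros(8), np.ones(8)

y_train, mean1, var1 = bn_train_step(x, gamma, beta, mean0, var0)
print(y_train.mean(axis=0).round(6))  # each feature has ~0 mean after normalization
print(mean1)                          # moved 1% of the way from 0 towards the batch mean
```

With `momentum=0.99` the moving statistics change slowly, which is why an untrained layer still carries its initial values at inference time.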
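To understand the whole BN process, let's run an experiment:
1 Build the simplest possible network with a single BN layer and save its structure: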

import os

import tensorflow as tf  # TensorFlow 1.x API (placeholder, Session, tf.layers)
from tensorflow.python.framework import graph_util

# Input: a batch of 8-dimensional feature vectors
x = tf.placeholder(tf.float32, [None, 8], name='input')

# Inference mode: normalize with moving_mean / moving_variance instead of batch statistics
is_training = False

out = tf.layers.batch_normalization(x, training=is_training)

pb_dir = "./pb_dir"
if not os.path.exists(pb_dir):
    os.makedirs(pb_dir)

with tf.Session() as sess:
    # The layer is never trained, so every BN variable keeps its initial value
    sess.run(tf.global_variables_initializer())

    # Freeze the variables into constants, keeping only what the BN output node needs
    constant_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ['batch_normalization/batchnorm/add_1'])
    with tf.gfile.FastGFile(pb_dir + '/BNTest.pb', mode='wb') as f:
        f.write(constant_graph.SerializeToString())
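Because the network is never trained, the frozen graph simply contains the initial values of the BN variables. Assuming the `tf.layers.batch_normalization` defaults (gamma initialized to ones, beta to zeros, moving_mean to zeros, moving_variance to ones, epsilon = 0.001), the inference output reduces to x / sqrt(1.001). The NumPy sketch below computes that expected output for the same input used in step 2 (`0.1 * i * i`), giving a reference to compare the TensorFlow result against.

```python
import numpy as np

# Assumed initial values of the untrained tf.layers.batch_normalization layer
gamma = np.ones(8, dtype=np.float32)
beta = np.zeros(8, dtype=np.float32)
moving_mean = np.zeros(8, dtype=np.float32)
moving_variance = np.ones(8, dtype=np.float32)
eps = np.float32(1e-3)

# Same test input as step 2: Testx[0][i] = 0.1 * i * i
Testx = np.array([[0.1 * i * i for i in range(8)]], dtype=np.float32)

# Inference-mode BN formula
expected = gamma * (Testx - moving_mean) / np.sqrt(moving_variance + eps) + beta
print(expected)
```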

2 Load the saved model structure

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API

# Deterministic test input: Testx[0][i] = 0.1 * i * i
Testx = np.zeros((1, 8))
for i in range(8):
    Testx[0][i] = 0.1 * i * i
print(Testx)

# Load the frozen graph and import it into the default graph (no name prefix)
with tf.gfile.FastGFile('pb_dir/BNTest.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')

with tf.Session() as sess:
    # Write the graph to logs/ so it can be inspected with: tensorboard --logdir logs/
    writer = tf.summary.FileWriter("logs/", sess.graph)
    writer.close()

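Viewing the graph written to logs/ in TensorBoard shows that the forward BN computation at inference time is actually very simple, and its nodes map clearly, one-to-one, onto the terms of the formula.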
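The inference formula is y = gamma * (x - mean) / sqrt(var + eps) + beta. Judging by the node names in the frozen graph, TensorFlow 1.x (`tf.nn.batch_normalization`) evaluates it in a rearranged form: first a per-feature scale `gamma * rsqrt(var + eps)`, then `x * scale + (beta - mean * scale)`. The NumPy sketch below mirrors that ordering, with the node each step appears to correspond to noted in a comment, and checks it against the direct formula. The step-to-node mapping is my reading of the graph, not documented API, and the helper names are hypothetical.

```python
import numpy as np


def bn_direct(x, gamma, beta, mean, var, eps=1e-3):
    """Textbook inference formula."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta


def bn_as_graph(x, gamma, beta, mean, var, eps=1e-3):
    """Same computation, rearranged the way the frozen graph's node names suggest."""
    add = var + eps              # batchnorm/add   (eps is the constant batchnorm/add/y)
    rsqrt = 1.0 / np.sqrt(add)   # batchnorm/Rsqrt
    mul = rsqrt * gamma          # batchnorm/mul   = gamma * rsqrt(var + eps), the "scale"
    mul_1 = x * mul              # batchnorm/mul_1 = x * scale
    mul_2 = mean * mul           # batchnorm/mul_2 = mean * scale
    sub = beta - mul_2           # batchnorm/sub   = beta - mean * scale, the "shift"
    return mul_1 + sub           # batchnorm/add_1 = final output


rng = np.random.default_rng(1)
x = rng.normal(size=(4, 8))
gamma, beta = rng.uniform(0.5, 2.0, 8), rng.normal(size=8)
mean, var = rng.normal(size=8), rng.uniform(0.1, 3.0, 8)

print(np.allclose(bn_direct(x, gamma, beta, mean, var),
                  bn_as_graph(x, gamma, beta, mean, var)))  # True
```

Folding the four parameters into one scale and one shift per feature means inference costs only one multiply and one add per element. It also means the `mul` and `sub` nodes hold the folded scale and shift, not the raw gamma and beta.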
3 Compute the BN result by hand

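# Continues the step-2 script: the frozen graph is already imported and Testx is defined
with tf.Session() as sess:
    writer = tf.summary.FileWriter("logs/", sess.graph)

    # Node names generated by tf.nn.batch_normalization. Note that 'batchnorm/sub' holds
    # beta - moving_mean * gamma * rsqrt(moving_variance + eps) and 'batchnorm/mul' holds
    # gamma * rsqrt(moving_variance + eps); they are NOT the raw beta and gamma
    beta = sess.run('batch_normalization/batchnorm/sub:0')
    Mm = sess.run('batch_normalization/moving_mean:0')       # moving mean
    gamma = sess.run('batch_normalization/batchnorm/mul:0')
    Mv = sess.run('batch_normalization/moving_variance:0')   # moving variance
    y = sess.run('batch_normalization/batchnorm/add/y:0')    # epsilon (0.001 by default)

    See1 = sess.run('batch_normalization/batchnorm/mul:0')
    See2 = sess.run('batch_normalization/batchnorm/mul_1:0', feed_dict={'input:0': Testx})  # x * scale
    See3 = sess.run('batch_normalization/batchnorm/add_1:0', feed_dict={'input:0': Testx})  # BN output
    print(See3)
    writer.close()


# Manual computation: (gamma / sqrt(var + eps)) * x + (beta - mean * gamma / sqrt(var + eps))
temp = y + Mv
temp = np.sqrt(temp)
temp = gamma / temp
temp2 = temp.copy()  # copy: otherwise the in-place loop below would also modify temp2
for i in range(Testx.shape[1]):
    temp[i] = temp[i] * Testx[0][i]
temp2 = temp2 * Mm
temp2 = beta - temp2
temp = temp + temp2
print(temp)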

Final output (first line: TensorFlow's See3; second line: the manual computation):
[[0. 0.09995004 0.39980015 0.89955026 1.5992006 2.498751
3.598201 4.8975515 ]]
[0. 0.0999001 0.3996004 0.8991009 1.5984015 2.4975023 3.5964036
4.895105 ]
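The two results differ slightly, but the gap is not an rsqrt precision issue. The tensor read as gamma, batch_normalization/batchnorm/mul:0, already holds gamma * rsqrt(moving_variance + epsilon), so the manual computation divides by sqrt(moving_variance + epsilon) a second time. Likewise, batchnorm/sub:0 holds beta - moving_mean * gamma * rsqrt(moving_variance + epsilon) rather than beta; this has no effect here only because moving_mean is 0. With the untrained defaults (gamma = 1, moving_variance = 1, epsilon = 0.001), TensorFlow outputs x / sqrt(1.001) ≈ 0.9995x while the manual code gives x / 1.001 ≈ 0.9990x, which matches the printed values: 0.09995004 versus 0.0999001 for x = 0.1.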
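The NumPy sketch below reproduces both printed values from the layer's initial parameters, assuming the `tf.layers.batch_normalization` defaults (gamma = 1, beta = 0, moving_mean = 0, moving_variance = 1, epsilon = 0.001), and shows a corrected manual computation that uses the raw gamma and beta. In the TensorFlow script the raw values could presumably be read from the frozen constants `batch_normalization/gamma:0` and `batch_normalization/beta:0`; those node names are an assumption based on the layer's default variable names.

```python
import numpy as np

# Assumed initial values of the untrained BN layer
gamma, beta = np.float32(1.0), np.float32(0.0)
mean, var, eps = np.float32(0.0), np.float32(1.0), np.float32(1e-3)

# Same test input as step 2: 0.1 * i * i
x = np.array([0.1 * i * i for i in range(8)], dtype=np.float32)

# What the graph actually stores
scale = gamma / np.sqrt(var + eps)   # value of batchnorm/mul:0
shift = beta - mean * scale          # value of batchnorm/sub:0
tf_like = x * scale + shift          # value of batchnorm/add_1:0

# Mirrors the step-3 manual code: treats `scale` as gamma and `shift` as beta,
# so the division by sqrt(var + eps) happens twice
buggy = (scale / np.sqrt(var + eps)) * x + (shift - mean * scale / np.sqrt(var + eps))

# Corrected manual computation with the raw gamma and beta
fixed = gamma / np.sqrt(var + eps) * x + (beta - mean * gamma / np.sqrt(var + eps))

print(tf_like[1], buggy[1], fixed[1])
print(np.allclose(tf_like, fixed))   # True
print(buggy[1] / tf_like[1])         # roughly 1 / sqrt(1.001)
```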
