CNN中做归一化用到的相关API(自己的小总结:tf.nn.moments()函数理解) 以及CNN中NHWC转NCHW的方法

Note1:CNN中NHWC转NCHW的方法:

比如卷积层输出的net形状为:[2, 3, 3, 4]
即:NHWC为[2, 3, 3, 4]
N:一个batch内图片的数量。
H:垂直高度方向的像素个数。
W:水平宽度方向的像素个数。
C:通道数

现为了做BN,想先将NHWC转为NCWH=[2, 4, 3, 3]
方法呢?可以使用TensorFlow中的tf.transpose函数实现!
n = tf.transpose(net, (0, 3, 2, 1)),,,其中第二个参数是转换后的张量中,原始张量的维度编号(原来是(0, 1, 2, 3))
得到n的形状是[2, 4, 3, 3]
即:NCWH为[2, 4, 3, 3]
下面举个例子:三维的,供参考

import tensorflow as tf
import numpy as np

x = [[[1,   2],
      [3,   4]],
     [[11,  22],
      [33,  44]],
     [[111, 222],
      [333, 444]]]  # x的shape为(3, 2, 2),通道数为3
y = tf.transpose(x, (1, 2, 0))  # 其中第二个参数是转换后的张量中,原始张量的维度编号。编号0原本在首位,现在处于末位。
with tf.Session() as sess:
    # print(y.eval())
    '''
    [[[  1  11 111]
      [  2  22 222]]

     [[  3  33 333]
     [  4  44 444]]]
    '''
    # print(x[0, :, :])
    # 出现报错:TypeError: list indices must be integers or slices, not tuple
    # 这是因为此时矩阵存储在列表(list)中,而列表中的每一个元素大小可能不同,因此不能直接取其某一列进行操作
    # 解决方案
    # 可以利用numpy.array函数将其转变为标准矩阵,再对其进行取某一列的操作:若下所示:
    # print(np.array(x)[0, :, :])
    '''
    取第一个维度首元素如下:
    [[1 2]
     [3 4]]
    '''
    # print(np.array(x)[:, :, 0])
    '''
    [[  1   3]
     [ 11  33]
     [111 333]]
    '''
    # print(y[0, :, :])  # Tensor("strided_slice:0", shape=(2, 3), dtype=int32)
    # print(y[0, :, :].eval())  # 等价于print(sess.run(y[0, :, :]))
    '''
    取y第一个维度首元素如下:
    [[  1  11 111]
     [  2  22 222]]
    '''
    # print(np.shape(y))  # (2, 2, 3)
    print(y[:, :, 0].eval())
    '''
    取y最后一个维度首元素如下:
    [[1 2]
     [3 4]]
    '''

    # 可见x[0, :, :] = y[:, :, 0]。张量已经由NCHW转换为NHWC格式。

Note2:tf.nn.moments()函数理解

import numpy as np
import tensorflow as tf
net = tf.constant(np.reshape(np.asarray(range(0, 72)), (2, 3, 3, 4)))
n = tf.transpose(net, (0, 3, 2, 1))
# BN中
m0, n0 = tf.nn.moments(n, axes=(0, 2, 3))
m1, v1 = tf.nn.moments(net, axes=(0, 1, 2))
m2, v2 = tf.nn.moments(net, axes=(0, 2, 1))
with tf.Session():
    '''查看一下net和n的值'''
    print(net.eval())
    print(10*'-')
    print(n.eval())
    
    '''查看一下transpose后值的情况'''
    print(net[:, :, :, 0].eval())
    print(10*'-')
    print(n[:, 0, :, :].eval())
    
    '''验证效果:发现transpose后,方便了计算,而且计算结果正确'''
    print(m1.eval())  # [34 35 36 37]
    print(m0.eval())  # [34 35 36 37]
    print(m2.eval())  # [34 35 36 37]

手写解释代码计算过程
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值