【tensorflow】tf.nn.conv2d的使用

官方文档地址

接口如下

tf.nn.conv2d(
input, filters, strides, padding, data_format=‘NHWC’, dilations=None,
name=None
)

input数据

input为入参,其shape必须为4维的,其中每一维度含义如下

  • N:Batch Number
  • H:Height
  • W:Width
  • C:Num of Channles

但是顺序由data_format来指定,一般为"NHWC"或者"NCHW"

keras中的data_format一般为"NHWC"

filter数据

filter则为神经网络训练出来的kernel值,它为4维。
[filter_height, filter_width, in_channels, out_channels]

注:out_channels即为filter的个数

算法

  1. 压平filter为二维matrix,称为张量A
    压平的目的shape为
    [filter_height * filter_width * in_channels, output_channels]

  2. 将Input中元素转换为虚拟tensor,称为张量B
    虚拟tensor的shape为
    [batch, out_height, out_width, filter_height * filter_width * in_channels]

  3. 然后将上述两者做tenserdot运算,得到张量
    C = B   t e n s o r d o t   A C=B\ tensordot\ A C=B tensordot A

计算结果shape的公式

Batch个数和输入通道数是不变的,主要是输出的高与宽的值
o u t _ w i d t h = ( i n p u t _ w i d t h − f i l t e r _ w i d t h ) / s t r i d e _ w i d t h + 1 o u t _ h e i g h t = ( i n p u t _ h e i g h t − f i l t e r _ h e i g h t ) / s t r i d e _ h e i g h t + 1 out\_width = (input\_width-filter\_width)/stride\_width+1 \newline out\_height = (input\_height-filter\_height)/stride\_height+1 out_width=(input_widthfilter_width)/stride_width+1out_height=(input_heightfilter_height)/stride_height+1
所以输出shape为**(N, OUT_H, OUT_W, C)**


input shap到虚拟tensor shape举例

假设input data为 [1,2,3,4,5,6,7,8,9],shape为[1,3,3,1]

我们把数据下标标记好,那么所有数据如下
(1,3,3,1)shape下的数据下标
同样,假设filter为(1,2,2,1),使用该filter在input data上滑动截取[2,2],效果如下
filter作用效果

根据上面介绍的shape计算公式,虚拟tensor的shape应该为(1,2,2,4),一共四个16个数据,我们穷举出所有的数值来看一下:

(0,0,x,x)

(0,0,x,x)

(0,1,x,x)

(0,1,x,x)

所以一共生成了16个数据,shape为(1,2,2,4)


代码样例

# Batch 1
# Height 3
# Width 3
# Channels 1
input_shape=(1,3,3,1)
x_in=np.linspace(1,9,num=9).reshape(input_shape)

# Height 2
# Width 2
# in channels 1
# out channels 1
kernel_in = np.array([1, 2, 3, 4]).reshape((2,2,1,1))
print('kernel shape: ', kernel_in.shape)

x = tf.constant(x_in, dtype=tf.float32)
kernel = tf.constant(kernel_in, dtype=tf.float32)
y = tf.nn.conv2d(x, kernel, strides=[1, 1, 1, 1], padding='VALID')

print("============================")
print(y)

结果

  • 2
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值