以2个3通道的分辨率为4*4的图片说明:
x = np.arange(96).reshape(2, 3, 4, 4)
其结果为:
array(
[
[
[
[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]
],
[
[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]
],
[
[32, 33, 34, 35],
[36, 37, 38, 39],
[40, 41, 42, 43],
[44, 45, 46, 47]
]
],
[
[
[48, 49, 50, 51],
[52, 53, 54, 55],
[56, 57, 58, 59],
[60, 61, 62, 63]
],
[
[64, 65, 66, 67],
[68, 69, 70, 71],
[72, 73, 74, 75],
[76, 77, 78, 79]
],
[
[80, 81, 82, 83],
[84, 85, 86, 87],
[88, 89, 90, 91],
[92, 93, 94, 95]
]
]
])
设pool_h=pool_w=2, stride=2,pad=0, 则可得:out_h=out_w=2。
通过函数im2col将4维的图片信号转化为2维矩阵:
col = im2col(x, pool_h=2, pool_w=2, stride=2, pad=0)
array([[ 0., 1., 4., 5., 16., 17., 20., 21., 32., 33., 36., 37.],
[ 2., 3., 6., 7., 18., 19., 22., 23., 34., 35., 38., 39.],
[ 8., 9., 12., 13., 24., 25., 28., 29., 40., 41., 44., 45.],
[10., 11., 14., 15., 26., 27., 30., 31., 42., 43., 46., 47.],
[48., 49., 52., 53., 64., 65., 68., 69., 80., 81., 84., 85.],
[50., 51., 54., 55., 66., 67., 70., 71., 82., 83., 86., 87.],
[56., 57., 60., 61., 72., 73., 76., 77., 88., 89., 92., 93.],
[58., 59., 62., 63., 74., 75., 78., 79., 90., 91., 94., 95.]])
由于是池化,所以需要按照不同通道拆分:
col = col.reshape(-1, pool_h*pool_w)
其中0,1,4,5是第1张图片第1通道第1次被卷积核扫描的数据。16,17,20,21是第1张图片第2通道第1次被卷积核扫描的数据。依次类推。
array([[ 0., 1., 4., 5.],
[16., 17., 20., 21.],
[32., 33., 36., 37.],
[ 2., 3., 6., 7.],
[18., 19., 22., 23.],
[34., 35., 38., 39.],
[ 8., 9., 12., 13.],
[24., 25., 28., 29.],
[40., 41., 44., 45.],
[10., 11., 14., 15.],
[26., 27., 30., 31.],
[42., 43., 46., 47.],
[48., 49., 52., 53.],
[64., 65., 68., 69.],
[80., 81., 84., 85.],
[50., 51., 54., 55.],
[66., 67., 70., 71.],
[82., 83., 86., 87.],
[56., 57., 60., 61.],
[72., 73., 76., 77.],
[88., 89., 92., 93.],
[58., 59., 62., 63.],
[74., 75., 78., 79.],
[90., 91., 94., 95.]])
按照最大值求池化:
out =np.max(col, axis=1)
其结果为:
array([ 5., 21., 37., 7., 23., 39., 13., 29., 45., 15., 31., 47., 53.,
69., 85., 55., 71., 87., 61., 77., 93., 63., 79., 95.])
通过转置求结果是其难点:
out = out.reshape(2, out_h=2, out_w=2, C=3).transpose(0, 3, 1, 2)
转置之前其结果为:
array([
[
[[ 5., 21., 37.],
[ 7., 23., 39.]],
[[13., 29., 45.],
[15., 31., 47.]]
],
[
[[53., 69., 85.],
[55., 71., 87.]],
[[61., 77., 93.],
[63., 79., 95.]]
]
])
其shape为(2,2,2,3)。主要是理解其维度与轴方向是一一对应的。即第0个维度对应axis=0, 第1个维度对应axis=1, 第2个维度对应axis=2, 第3个维度对应axis=3。对上述数组来说即axis=0的方向是(5,53),axis=1的方向是(5,13),axis=2的方向是(5,7),axis=3的方向是(5,21,37)。
以此来理解transpose(0,3,1,2) 即axis=3与axis=1的方向的数据对调。其结果为:
array([
[
[[ 5., 13.],
[ 7., 15.]],
[[21., 29.],
[23., 31.]],
[[37., 45.],
[39., 47.]]
],
[
[[53., 61.],
[55., 63.]],
[[69., 77.],
[71., 79.]],
[[85., 93.],
[87., 95.]]
]
])
然后再将即axis=2(5,7)与axis=3(5,13)的方向的数据对调得到最终结果:
array([
[
[[ 5., 7.],
[13., 15.]],
[[21., 23.],
[29., 31.]],
[[37., 39.],
[45., 47.]]
],
[
[[53., 55.],
[61., 63.]],
[[69., 71.],
[77., 79.]],
[[85., 87.],
[93., 95.]]
]
])