手撕池化层(numpy)
池化层定义
池化是缩小高、宽方向上空间尺寸的运算。
除了 max 池化之外,还有 average 池化。一般来说,池化的窗口大小会和步幅设定成相同的值。
池化层没有需要学习的参数,输入数据和输出数据的通道数不会发生变化,计算是按通道独立进行的。此外,池化对输入数据的微小变化(如微小的位置偏移)具有鲁棒性(健壮性)。
numpy手撕池化
池化过程和卷积过程一样,会涉及 im2col 和 col2im 操作。由于这两个操作在卷积部分已经讲过,这里不再赘述,想了解的可以参看《手撕卷积层》。
import os, sys
import collections
import numpy as np
class Pooling:
    """Max-pooling layer over (N, C, H, W) inputs, built on im2col/col2im.

    Each channel is pooled independently and the layer has no learnable
    parameters. ``forward`` caches the input and the per-window argmax so
    that ``backward`` can route each upstream gradient to the input element
    that produced the window maximum.

    :param pool_h: pooling window height
    :param pool_w: pooling window width
    :param stride: step between neighbouring windows (default 1)
    :param padding: rows/columns added on every spatial border; padded cells
        are filled with ``-inf`` so they never win the max. A window lying
        entirely inside the padding therefore yields ``-inf``.
    """

    def __init__(self, pool_h, pool_w, stride=1, padding=0):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.padding = padding
        self.x = None        # input cached by forward() for backward()
        self.arg_max = None  # argmax inside each window; rows ordered (N, out_h, out_w, C)

    @staticmethod
    def _out_size(size, kernel, stride, padding):
        """Return the number of window positions along one spatial axis."""
        return (size + 2 * padding - kernel) // stride + 1

    def forward(self, x):
        """Max-pool ``x``.

        :param x: input array of shape (N, C, H, W)
        :return: pooled array of shape (N, C, out_h, out_w)
        """
        N, C, H, W = x.shape
        # Must include padding exactly as im2col does, otherwise the reshape
        # below fails whenever padding > 0.
        out_h = self._out_size(H, self.pool_h, self.stride, self.padding)
        out_w = self._out_size(W, self.pool_w, self.stride, self.padding)

        # Pad with -inf (not 0) so a padded cell can never be chosen as the
        # maximum, e.g. when every real value in a window is negative.
        col = self.im2col(x, self.pool_h, self.pool_w, self.stride, self.padding,
                          pad_value=-np.inf)
        # im2col columns are ordered (C, pool_h, pool_w); splitting them yields
        # one row per (n, oh, ow, c) window.
        col = col.reshape(-1, self.pool_h * self.pool_w)

        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        self.x = x
        self.arg_max = arg_max
        return out

    def foward(self, x):
        """Misspelled alias of :meth:`forward`, kept for backward compatibility."""
        return self.forward(x)

    def backward(self, dout):
        """Propagate ``dout`` (N, C, out_h, out_w) back to the input.

        Each upstream gradient is sent only to the input element that was the
        maximum of its window; every other element receives zero.

        :return: gradient w.r.t. the input, same shape as the cached input
        :raises RuntimeError: if called before :meth:`forward`
        """
        if self.x is None or self.arg_max is None:
            raise RuntimeError("backward() called before forward()")

        # (N, out_h, out_w, C) matches the row order of self.arg_max.
        dout = dout.transpose(0, 2, 3, 1)
        pool_size = self.pool_h * self.pool_w

        # np.zeros takes the shape as a single tuple argument.
        dmax = np.zeros((dout.size, pool_size))
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        dmax = dmax.reshape(dout.shape + (pool_size,))

        # Back to im2col layout: rows (N, out_h, out_w), columns (C, pool_h, pool_w).
        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        return self.col2im(dcol, self.x.shape, self.pool_h, self.pool_w,
                           self.stride, self.padding)

    def im2col(self, input_data, conv_h, conv_w, stride=1, padding=0, pad_value=0):
        """Unfold sliding windows of an image batch into a 2-D matrix.

        :param input_data: array of shape (N, C, H, W)
        :param conv_h: window (kernel) height
        :param conv_w: window (kernel) width
        :param stride: step between windows
        :param padding: border size added on each spatial side
        :param pad_value: constant written into the padded border (default 0)
        :return: float64 array of shape (N * out_h * out_w, C * conv_h * conv_w);
            rows ordered (N, out_h, out_w), columns ordered (C, conv_h, conv_w)
        """
        N, C, H, W = input_data.shape
        out_h = self._out_size(H, conv_h, stride, padding)
        out_w = self._out_size(W, conv_w, stride, padding)

        # Convert first so a non-finite pad_value (e.g. -inf) also works for
        # integer inputs; the result is float64 either way.
        img = np.pad(np.asarray(input_data, dtype=np.float64),
                     [(0, 0), (0, 0), (padding, padding), (padding, padding)],
                     mode="constant", constant_values=pad_value)
        col = np.zeros((N, C, conv_h, conv_w, out_h, out_w))

        # For each offset inside the window, copy the strided grid of pixels
        # that sit at that offset in every window.
        for y in range(conv_h):
            y_max = y + stride * out_h
            for x in range(conv_w):
                x_max = x + stride * out_w
                col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

        return col.transpose(0, 4, 5, 1, 2, 3).reshape(N * out_h * out_w, -1)

    def col2im(self, col, input_shape, conv_h, conv_w, stride=1, padding=0):
        """Fold an im2col-style matrix back into an image batch (inverse scatter).

        Values from overlapping windows are summed, which is what gradient
        back-propagation requires.

        :param col: array of shape (N * out_h * out_w, C * conv_h * conv_w)
        :param input_shape: shape of the original input, e.g. (10, 1, 28, 28)
        :param conv_h: window (kernel) height
        :param conv_w: window (kernel) width
        :param stride: step between windows
        :param padding: border size that was added on each spatial side
        :return: array of shape ``input_shape`` (padding cropped away)
        """
        N, C, H, W = input_shape
        out_h = self._out_size(H, conv_h, stride, padding)
        out_w = self._out_size(W, conv_w, stride, padding)

        # Undo im2col's layout: reshape the incoming data (not fresh zeros!)
        # into (N, C, conv_h, conv_w, out_h, out_w).
        col = col.reshape(N, out_h, out_w, C, conv_h, conv_w).transpose(0, 3, 4, 5, 1, 2)

        # Padded canvas; only the un-padded interior is returned.
        img = np.zeros((N, C, H + 2 * padding + stride - 1, W + 2 * padding + stride - 1))
        for y in range(conv_h):
            y_max = y + stride * out_h
            for x in range(conv_w):
                x_max = x + stride * out_w
                # += so contributions from overlapping windows accumulate.
                img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]

        return img[:, :, padding:H + padding, padding:W + padding]