摘要
本文使用纯 Python 和 PyTorch 对比实现 MaxPool 函数及其反向传播.
相关
原理和详细解释, 请参考文章 :
池化层MaxPool函数详解及反向传播的公式推导
系列文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981
正文
import torch
import numpy as np
class MaxPool2D:
def __init__(self, kernel_size=(2, 2), stride=2):
self.stride = stride
self.kernel_size = kernel_size
self.w_height = kernel_size[0]
self.w_width = kernel_size[1]
self.x = None
self.in_height = None
self.in_width = None
self.out_height = None
self.out_width = None
self.arg_max = None
def __call__(self, x):
self.x = x
self.in_height = np.shape(x)[0]
self.in_width = np.shape