参考资料
链接:https://www.cnblogs.com/wzyuan/p/9880342.html
上面是一个关于pytorch对ResNet的官方实现的解读,作者已经写得很清楚啦,但是由于我是一个代码小白,所以对一些细节上的处理不太明白,这里花点功夫把它们搞清楚~
细节
-
downsample参数的作用
代码:
def init(self, inplanes, planes, stride=1, downsample=None):
#这里省略了其它东西
self.downsample = downsampledef forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return outdownsample参数作用:
这个参数代表的应该不是一个数,而是某个函数,可用于调整维度。
如果上一个ResidualBlock的输出维度和当前的ResidualBlock的维度不一样,那就对这个x进行downSample操作,如果维度一样,直接加就行了,这时直接out+=residual。
(参考知乎:https://www.zhihu.com/question/68247422)什么是downsample/下采样?:
缩小图像(或称为下采样(subsampled)或降采样(downsampled))的主要目的有两个:1、使得图像符合显示区域的大小;2、生成对应图像的缩略图。放大图像(或称为上采样(upsampling)或图像插值(interpolating))的主要目的是放大原图像,从而可以显示在更高分辨率的显示设备上。
下采样原理:对于一幅图像I尺寸为MN,对其进行s倍下采样,即得到(M/s)(N/s)尺寸的得分辨率图像,当然s应该是M和N的公约数才行,如果考虑的是矩阵形式的图像,就是把原始图像s*s窗口内的图像变成一个像素,这个像素点的值就是窗口内所有像素的均值。
上采样原理:图像放大几乎都是采用内插值方法,即在原有图像像素的基础上在像素点之间采用合适的插值算法插入新的元素。
(参考:https://blog.csdn.net/stf1065716904/article/details/78450997) -
self.modules()的作用
def __init__(self, block, layers, num_classes=1000): self.inplanes = 64 super(ResNet, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.layer1 = self._make_layer(block, 64, layers[0]) self.layer2 = self._make_layer(block, 128, layers[1], stride=2) self.layer3 = self._make_layer(block, 256, layers[2], stride=2) self.layer4 = self._make_layer(block, 512, layers[3], stride=2) self.avgpool = nn.AvgPool2d(7, stride=1) self.fc = nn.Linear(512 * block.expansion, num_classes) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0)
最后一小段代码的意思是:如果用到的模块是nn.Conv2d,则用nn.init.kaiming_normal_函数初始化m.weight;如果用到的模块是nn.BatchNorm2d,则用nn.init.constant_分别初始化m.weight和m.bias。
-
形参kwargs
当定义的函数形参中有*arg时,说明:
调用函数时,多出的tuple(eg.1,2,3)的整体(eg.(1,2,3))构成arg;
当定义的函数形参中有**kwarg时,说明:
调用函数时,多出的dict元素(eg.keyword1 = “bar”, keyword2 = “foo”)的整体(eg.{‘keyword2’: ‘foo’, ‘keyword1’: ‘bar’})构成kwarg。
(参考简书:https://www.jianshu.com/p/e0d4705e8293)