最近在跑实验的过程中一直在使用resnet50和resnet34,为了弄清楚网络的结构和原理的实现,打开resnet的源码进行了学习。
残差网络学习的原理
针对神经网络过深而导致的学习准确率饱和甚至是退化现象,resnet通过将若干个卷积层前的输入x直接与经过卷积层卷积学习过的特征进行叠加,假如经过卷积层学习到的特征为H(x),那么经过若干卷积层后得到的特征F(x)=H(x) + x,那该网络需要学习的仅为H(x)。
![28ac307855658d5f87695d9191e1d666.png](https://i-blog.csdnimg.cn/blog_migrate/cef2d5af31cd8736fd8b8f2d50845292.png)
Resnet源码分析
Resnet网络主要是若干个网络的堆叠,Resnet内部实现了两种网络,一个是Basicblock,一个是Bottleneck。
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
ide