ResNet网络
ResNet是一种残差网络,可以把它理解为一个模块,这个模块经过堆叠可以构成一个很深的网络。ResNet通过增加残差连接(shortcut connection),显地让网络中的层拟合残差映射(residual mapping)。ResNet网络结构为多个Residual Block的串联
ResNet不再尝试学习x到H(x)的潜在映射,而是学习两者之间的不同,或说残差(residual)。然后,为了计算H(x),可将残差加到输入上。假设残差是F(x)= H(x) - x,我们将尝试学习F(x)+ x,而不是直接学习H(x)。
特点:
- 与纯层的堆叠相比,ResNet多了很多“残差链接”,即shortcut路径,也就是 Residual Block
- ResNet中,所有的Residual Block都没有pooling层,降采样是通过conv的stride实现的
- 通过Average Pooling得到最终的特征,而不是通过全连接层
- 每个卷积层之后都紧接着BatchNorm 层
Summay feature:
ResNet结构非常容易修改和扩展,通过调整block内的channel数量以及堆叠的block数量,就可以很容易地调整网络的宽度和深度,来得到不同表达能力的网络,而不用过多地担心网络的“退化”问题,只要训练数据足够,逐步加深网络,就可以获得更好的性能表现。
创建ResNet网络
### resnet模型
class ResnetbasicBlock(nn.Module):
def __init__(self,in_channels,out_channels):
super().__init__()
self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1,bias=False)
self.bn1 = nn.BatchNorm2d(out_channels) # 标准化
self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels) # 标准化
def forward(self,x):
residual = x
out = self.conv1(x)
out = F.relu(self.bn1(out),inplace = True)
out = self.conv2(x)
out = F.relu(self.bn2(out), inplace=True)
out += residual # 输入与最后的输出相加
return F.relu(out)