ResNet图解
nn.Module详解
1. 在 PyTorch 上搭建 ResNet-18
1.1 ResNet block子模块
import torch
from torch import nn
from torch.nn import functional as F
class ResBlk(nn.Module):
    """Basic ResNet residual block: two 3x3 conv+BN layers plus a shortcut.

    The forward pass computes
    ``relu(shortcut(x) + bn2(conv2(relu(bn1(conv1(x))))))``.

    Args:
        ch_in: Number of channels of the input tensor.
        ch_out: Number of channels produced by the block.
        stride: Stride of the first 3x3 convolution (and of the 1x1
            projection shortcut when one is used). Values > 1 reduce the
            spatial size.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        # Python 3 style; the Python 2 form would be
        # super(ResBlk, self).__init__()
        super().__init__()
        # Main path: the first conv may change the channel count and/or
        # downsample spatially.
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3,
                               stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # The second conv keeps both the channel count and spatial size.
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3,
                               stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # Shortcut: identity by default (an empty Sequential returns its input).
        self.extra = nn.Sequential()
        # If the input and output differ in channel count OR the main path
        # downsamples (stride != 1), an identity shortcut cannot be added to
        # the main-path output. In that case, project x with a 1x1 conv that
        # uses the same stride: [b, ch_in, h, w] => [b, ch_out, h', w'].
        if stride != 1 or ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """Apply the residual block to ``x``.

        Args:
            x: Input tensor whose channel dimension (dim 1) equals ``ch_in``;
                assumed to be [b, ch_in, h, w] as expected by ``nn.Conv2d``.

        Returns:
            Tensor with ``ch_out`` channels, spatially downsampled by
            ``stride``, passed through a final ReLU.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Residual connection: add the (possibly projected) input.
        out = self.extra(x) + out
        return F.relu(out)
1.2 ResNet18主模块
class ResNet18(nn.Module):
"""
主模块
"""
def __init__(self):
super(ResNet18, self)