# 残差模块示例:
class ResidualBlock(nn.Module):
    """Residual block computing ``shortcut(x) + conv(x)``.

    ``conv`` is Conv2d -> BatchNorm2d -> PReLU -> Conv2d -> BatchNorm2d.
    With the defaults (``in_channel == out_channel`` and ``s == 1``) the
    shortcut is the identity, giving the classic ``x + F(x)`` block.

    Args:
        in_channel: Number of channels of the input tensor.
        out_channel: Number of channels produced by the block.
        k: Convolution kernel size. Should be odd so that ``padding = k // 2``
            keeps the spatial size unchanged at stride 1.
        s: Stride of the first convolution. When ``s != 1`` or the channel
            counts differ, a 1x1 Conv2d + BatchNorm2d projection is used on
            the skip path so both branches have the same shape.
    """

    def __init__(self, in_channel=64, out_channel=64, k=3, s=1):
        super().__init__()
        # "Same" padding for any odd kernel size (equals 1 for k=3).
        padding = k // 2
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=k, stride=s, padding=padding),
            # BatchNorm normalizes the conv output, which has out_channel channels.
            nn.BatchNorm2d(out_channel),
            nn.PReLU(),  # LeakyReLU variant whose negative slope is learned
            # Consumes the first conv's output; stride 1 so only the first conv downsamples.
            nn.Conv2d(out_channel, out_channel, kernel_size=k, stride=1, padding=padding),
            nn.BatchNorm2d(out_channel),
        )
        if in_channel == out_channel and s == 1:
            # Shapes already match. Identity has no parameters, so the
            # state_dict keys are exactly those of self.conv.
            self.shortcut = nn.Identity()
        else:
            # Projection shortcut: 1x1 conv with the same stride yields the
            # same output size as self.conv for odd k, so the add is valid.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=s, bias=False),
                nn.BatchNorm2d(out_channel),
            )

    def forward(self, x):
        """Return the skip path (identity or projection) plus the residual branch."""
        # The essence of the residual design: add the input back onto F(x).
        return self.shortcut(x) + self.conv(x)