Paper: Residual Attention Network for Image Classification
Paper link: https://arxiv.org/abs/1704.06904
Two network configurations are provided:
def attention56():
    return Attention([1, 1, 1])

def attention92():
    return Attention([1, 2, 3])
This is the structure given in the original paper.
So how are the numbers 56 and 92 computed?
The source code:
"""residual attention network in pytorch
[1] Fei Wang, Mengqing Jiang, Chen Qian, Shuo Yang, Cheng Li, Honggang Zhang, Xiaogang Wang, Xiaoou Tang
Residual Attention Network for Image Classification
https://arxiv.org/abs/1704.06904
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
#"""The Attention Module is built by pre-activation Residual Unit [11] with the
#number of channels in each stage is the same as ResNet [10]."""
class PreActResidualUnit(nn.Module):
"""PreAct Residual Unit
Args:
in_channels: residual unit input channel number
out_channels: residual unit output channel numebr
stride: stride of residual unit when stride = 2, downsample the featuremap
"""
def __init__(self, in_channels, out_channels, stride):
super().__init__()
bottleneck_channels = int(out_channels / 4)
self.residual_function = nn.Sequential(
#1x1 conv
nn.BatchNorm2d(in_channels),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels, bottleneck_channels, 1, stride),
#3x3 conv
nn.BatchNorm2d(bottleneck_channels),
nn.ReLU(inplace=True),
nn.Conv2d(bottleneck_channels, bottleneck_channels, 3, padding=1),
#1x1 conv
nn.BatchNorm2d(bottleneck_channels),
nn.ReLU(inplace=True),
nn.Conv2d(bottleneck_channels, out_channels, 1)
)
self.shortcut = nn.Sequential()
if stride != 2 or (in_channels != out_channels):
self.shortcut = nn.Conv2d(in_channels, out_channels, 1, stride=stride)
def forward(self, x):
res = self.residual_function(x)
shortcut = self.shortcut(x)
return res + shortcut
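As a quick sanity check (my own sketch, not part of the original code), the unit behaves like a standard bottleneck: stride 2 halves the spatial size, and the channel count always becomes out_channels.

#illustrative shape check, assuming PreActResidualUnit and torch from above
unit = PreActResidualUnit(64, 256, stride=2)
x = torch.randn(1, 64, 32, 32)
print(unit(x).shape)  #torch.Size([1, 256, 16, 16])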
class AttentionModule1(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()
        #"""The hyperparameter p denotes the number of preprocessing Residual
        #Units before splitting into trunk branch and mask branch. t denotes
        #the number of Residual Units in trunk branch. r denotes the number of
        #Residual Units between adjacent pooling layer in the mask branch."""
        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown3 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown4 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup3 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup4 = self._make_residual(in_channels, out_channels, r)

        self.shortcut_short = PreActResidualUnit(in_channels, out_channels, 1)
        self.shortcut_long = PreActResidualUnit(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        ###We make the size of the smallest output map in each mask branch 7*7 to be consistent
        #with the smallest trunk output map size.
        ###Thus 3,2,1 max-pooling layers are used in mask branch with input size 56 * 56, 28 * 28, 14 * 14 respectively.
        x = self.pre(x)
        input_size = (x.size(2), x.size(3))

        x_t = self.trunk(x)

        #first downsample out 28
        x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        #28 shortcut
        shape1 = (x_s.size(2), x_s.size(3))
        shortcut_long = self.shortcut_long(x_s)

        #second downsample out 14 (pool x_s, not x, so the map keeps shrinking)
        x_s = F.max_pool2d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown2(x_s)

        #14 shortcut (use the dedicated shortcut_short unit)
        shape2 = (x_s.size(2), x_s.size(3))
        shortcut_short = self.shortcut_short(x_s)

        #third downsample out 7
        x_s = F.max_pool2d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown3(x_s)

        #mid
        x_s = self.soft_resdown4(x_s)
        x_s = self.soft_resup1(x_s)

        #first upsample out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=shape2)
        x_s += shortcut_short

        #second upsample out 28
        x_s = self.soft_resup3(x_s)
        x_s = F.interpolate(x_s, size=shape1)
        x_s += shortcut_long

        #third upsample out 56
        x_s = self.soft_resup4(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):
        layers = []
        for _ in range(p):
            layers.append(PreActResidualUnit(in_channels, out_channels, 1))

        return nn.Sequential(*layers)
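To make the 56 -> 28 -> 14 -> 7 -> 14 -> 28 -> 56 round trip of the mask branch concrete, here is a small shape check (my own sketch, assuming the classes defined above): the module must leave the spatial size unchanged.

#illustrative: AttentionModule1 preserves spatial size, since the mask branch
#downsamples three times and interpolates back through the recorded shapes
m1 = AttentionModule1(256, 256)
x = torch.randn(1, 256, 56, 56)
print(m1(x).shape)  #torch.Size([1, 256, 56, 56])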
class AttentionModule2(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()
        #"""The hyperparameter p denotes the number of preprocessing Residual
        #Units before splitting into trunk branch and mask branch. t denotes
        #the number of Residual Units in trunk branch. r denotes the number of
        #Residual Units between adjacent pooling layer in the mask branch."""
        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown3 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup3 = self._make_residual(in_channels, out_channels, r)

        self.shortcut = PreActResidualUnit(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        x = self.pre(x)
        input_size = (x.size(2), x.size(3))

        x_t = self.trunk(x)

        #first downsample out 14
        x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        #14 shortcut
        shape1 = (x_s.size(2), x_s.size(3))
        shortcut = self.shortcut(x_s)

        #second downsample out 7 (pool x_s, not x)
        x_s = F.max_pool2d(x_s, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown2(x_s)

        #mid
        x_s = self.soft_resdown3(x_s)
        x_s = self.soft_resup1(x_s)

        #first upsample out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=shape1)
        x_s += shortcut

        #second upsample out 28
        x_s = self.soft_resup3(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):
        layers = []
        for _ in range(p):
            layers.append(PreActResidualUnit(in_channels, out_channels, 1))

        return nn.Sequential(*layers)
class AttentionModule3(nn.Module):

    def __init__(self, in_channels, out_channels, p=1, t=2, r=1):
        super().__init__()

        assert in_channels == out_channels

        self.pre = self._make_residual(in_channels, out_channels, p)
        self.trunk = self._make_residual(in_channels, out_channels, t)
        self.soft_resdown1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resdown2 = self._make_residual(in_channels, out_channels, r)

        self.soft_resup1 = self._make_residual(in_channels, out_channels, r)
        self.soft_resup2 = self._make_residual(in_channels, out_channels, r)

        #defined but never called in forward; kept so the printed network below matches
        self.shortcut = PreActResidualUnit(in_channels, out_channels, 1)

        self.sigmoid = nn.Sequential(
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=1),
            nn.Sigmoid()
        )

        self.last = self._make_residual(in_channels, out_channels, p)

    def forward(self, x):
        x = self.pre(x)
        input_size = (x.size(2), x.size(3))

        x_t = self.trunk(x)

        #first (and only) downsample out 7
        x_s = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        x_s = self.soft_resdown1(x_s)

        #mid
        x_s = self.soft_resdown2(x_s)
        x_s = self.soft_resup1(x_s)

        #first upsample out 14
        x_s = self.soft_resup2(x_s)
        x_s = F.interpolate(x_s, size=input_size)

        x_s = self.sigmoid(x_s)
        x = (1 + x_s) * x_t
        x = self.last(x)

        return x

    def _make_residual(self, in_channels, out_channels, p):
        layers = []
        for _ in range(p):
            layers.append(PreActResidualUnit(in_channels, out_channels, 1))

        return nn.Sequential(*layers)
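The three modules are the same idea; they differ only in how many pool/interpolate pairs the mask branch takes (3, 2, 1 respectively, matching the 56x56, 28x28, 14x14 stage inputs quoted from the paper above). A quick illustrative check of all three (my own sketch, assuming the classes above):

#illustrative: each module variant preserves its stage's spatial size
for module, c, s in [(AttentionModule1, 256, 56),
                     (AttentionModule2, 512, 28),
                     (AttentionModule3, 1024, 14)]:
    y = module(c, c)(torch.randn(1, c, s, s))
    print(module.__name__, y.shape)
#AttentionModule1 torch.Size([1, 256, 56, 56])
#AttentionModule2 torch.Size([1, 512, 28, 28])
#AttentionModule3 torch.Size([1, 1024, 14, 14])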
class Attention(nn.Module):
    """residual attention network

    Args:
        block_num: attention module number for each stage
    """

    def __init__(self, block_num, class_num=100):
        super().__init__()

        self.pre_conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        self.stage1 = self._make_stage(64, 256, block_num[0], AttentionModule1)
        self.stage2 = self._make_stage(256, 512, block_num[1], AttentionModule2)
        self.stage3 = self._make_stage(512, 1024, block_num[2], AttentionModule3)
        self.stage4 = nn.Sequential(
            PreActResidualUnit(1024, 2048, 2),
            PreActResidualUnit(2048, 2048, 1),
            PreActResidualUnit(2048, 2048, 1)
        )
        self.avg = nn.AdaptiveAvgPool2d(1)
        #use class_num instead of a hard-coded 100 so the argument takes effect
        self.linear = nn.Linear(2048, class_num)

    def forward(self, x):
        x = self.pre_conv(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.avg(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)

        return x

    def _make_stage(self, in_channels, out_channels, num, block):
        layers = []
        layers.append(PreActResidualUnit(in_channels, out_channels, 2))

        for _ in range(num):
            layers.append(block(out_channels, out_channels))

        return nn.Sequential(*layers)
def attention56():
    return Attention([1, 1, 1])

def attention92():
    return Attention([1, 2, 3])
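Usage is straightforward; a minimal sketch (not from the original post) feeding a CIFAR-100-sized batch through attention56:

#illustrative usage: CIFAR-100-style 32x32 input, 100-way output
if __name__ == '__main__':
    net = attention56()
    out = net(torch.randn(2, 3, 32, 32))
    print(out.shape)                                 #torch.Size([2, 100])
    print(sum(p.numel() for p in net.parameters()))  #total parameter count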
The code is quite long-winded, and at first I had no patience to read it.
The figure in the paper is drawn so intricately that I did not want to study it, and the code turned out even more complex.
What to do?
Print the network:
Attention(
  (pre_conv): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
  )
  (stage1): Sequential(
    (0): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(2, 2))
        (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(64, 256, kernel_size=(1, 1), stride=(2, 2))
    )
    (1): AttentionModule1(
      (pre): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (trunk): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown3): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown4): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup3): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup4): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (shortcut_short): PreActResidualUnit(
        (residual_function): Sequential(
          (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut_long): PreActResidualUnit(
        (residual_function): Sequential(
          (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        )
        (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (sigmoid): Sequential(
        (0): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
      )
      (last): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (stage2): Sequential(
    (0): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(256, 128, kernel_size=(1, 1), stride=(2, 2))
        (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2))
    )
    (1): AttentionModule2(
      (pre): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (trunk): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown3): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup3): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (shortcut): PreActResidualUnit(
        (residual_function): Sequential(
          (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
          (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
        )
        (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (sigmoid): Sequential(
        (0): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
      )
      (last): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (stage3): Sequential(
    (0): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(512, 256, kernel_size=(1, 1), stride=(2, 2))
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2))
    )
    (1): AttentionModule3(
      (pre): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (trunk): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
        (1): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resdown2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup1): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (soft_resup2): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
      (shortcut): PreActResidualUnit(
        (residual_function): Sequential(
          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
          (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
        (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
      )
      (sigmoid): Sequential(
        (0): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
      )
      (last): Sequential(
        (0): PreActResidualUnit(
          (residual_function): Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
            (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1))
          )
          (shortcut): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
        )
      )
    )
  )
  (stage4): Sequential(
    (0): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(1024, 512, kernel_size=(1, 1), stride=(2, 2))
        (1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2))
    )
    (1): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(2048, 2048, kernel_size=(1, 1), stride=(1, 1))
    )
    (2): PreActResidualUnit(
      (residual_function): Sequential(
        (0): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1))
        (1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1))
      )
      (shortcut): Conv2d(2048, 2048, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (avg): AdaptiveAvgPool2d(output_size=1)
  (linear): Linear(in_features=2048, out_features=100, bias=True)
)
Searching this printout turns up Conv2d no fewer than 167 times and shortcut 44 times; there is no way to read it layer by layer.
So look at it from the other end: two configurations are already given.
92 - 56 = 36
attention92 ([1, 2, 3]) has three more attention modules than attention56 ([1, 1, 1]), so those three extra modules together account for the 36 layers.
Now look at how the network is stacked:
self.pre_conv = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True)
)
self.stage1 = self._make_stage(64, 256, block_num[0], AttentionModule1)
self.stage2 = self._make_stage(256, 512, block_num[1], AttentionModule2)
self.stage3 = self._make_stage(512, 1024, block_num[2], AttentionModule3)
self.stage4 = nn.Sequential(
    PreActResidualUnit(1024, 2048, 2),
    PreActResidualUnit(2048, 2048, 1),
    PreActResidualUnit(2048, 2048, 1)
)
self.avg = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(2048, class_num)
One pre_conv, four stages, and a final linear layer (see the one-liner below).
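You can confirm this top-level layout directly (an illustrative one-liner, not from the original post):

#illustrative: list the top-level children of the network
for name, _ in attention56().named_children():
    print(name)
#pre_conv, stage1, stage2, stage3, stage4, avg, linear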
Now look inside a stage:
layers = []
layers.append(PreActResidualUnit(in_channels, out_channels, 2))
for _ in range(num):
    layers.append(block(out_channels, out_channels))
So a stage is one PreActResidualUnit followed by num attention-module Blocks.
From 36 layers spread over 3 extra modules, one Block must contribute 12 layers, while a PreActResidualUnit has 3 convolutional layers. The 12 is consistent with the code: counting only the main path, each module runs pre (p = 1), trunk (t = 2), and last (p = 1) units, i.e. 4 residual units x 3 convs = 12; the mask branch is not counted toward the depth.
Now the numbers can be added up.
initial convolution ------------------------------------ 1 layer
3 stages, (12 + 3) x 3 --------------------------------- 45 layers
stage4, 3 PreActResidualUnits --------------------------- 9 layers
final fully connected layer ----------------------------- 1 layer
1 + 45 + 9 + 1 = 56: attention56 done.
And the other configuration, adding up the numbers the same way:
initial convolution ------------------------------------ 1 layer
3 stages, 15 + (12 x 2 + 3) + (12 x 3 + 3) -------------- 81 layers
stage4, 3 PreActResidualUnits --------------------------- 9 layers
final fully connected layer ----------------------------- 1 layer
1 + 81 + 9 + 1 = 92: exactly 92 layers.
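The same arithmetic can be written as a small function (my own sketch) to double-check both variants:

#illustrative: depth counted along the main path only;
#each stage = 1 leading PreActResidualUnit (3 convs) + n attention modules,
#each module contributing (p + t + p) = 4 residual units = 12 convs
def main_path_depth(block_num):
    depth = 1                                    #initial 3x3 conv
    depth += sum(12 * n + 3 for n in block_num)  #stage1..stage3
    depth += 3 * 3                               #stage4: 3 residual units
    depth += 1                                   #final fully connected layer
    return depth

print(main_path_depth([1, 1, 1]))  #56
print(main_path_depth([1, 2, 3]))  #92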
If this only works out by coincidence, please point out my mistake.