import torch
import torch.nn as nn
import torch.nn.functional as F


def _make_divisible(v, divisor=8, min_value=None):
    # Round the channel count v to the nearest multiple of divisor (default 8).
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not reduce the value by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
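
# A quick sanity check of the rounding behaviour (the example values are my
# own, not from the original source):
#   _make_divisible(37)  # -> 40: rounds up to the nearest multiple of 8
#   _make_divisible(32)  # -> 32: already a multiple of 8, unchanged
#   _make_divisible(10)  # -> 16: 8 would fall more than 10% below 10,
#                        #    so the guard adds another divisor
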
class MobileBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, nonLinear, SE, exp_size):
        super(MobileBlock, self).__init__()
        self.out_channels = out_channels
        self.nonLinear = nonLinear
        self.SE = SE
        padding = (kernel_size - 1) // 2
        # Use a residual connection only when the block preserves both
        # resolution and channel count.
        self.use_connect = stride == 1 and in_channels == out_channels
        # Choose the activation: "RE" -> ReLU, otherwise ("HS") -> h-swish.
        if self.nonLinear == "RE":
            activation = nn.ReLU
        else:
            activation = h_swish
        # 1x1 expansion convolution.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, exp_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(exp_size),
            activation(inplace=True)
        )
        # Depthwise convolution (groups == channels); bias is redundant
        # before BatchNorm, so it is disabled as in the expansion conv.
        self.depth_conv = nn.Sequential(
            nn.Conv2d(exp_size, exp_size, kernel_size=kernel_size, stride=stride,
                      padding=padding, groups=exp_size, bias=False),
            nn.BatchNorm2d(exp_size),
        )
        if self.SE:
            # Squeeze-and-Excite module on the expanded features.
            self.squeeze_block = SqueezeBlock(exp_size)
        # 1x1 pointwise projection. (Note: the MobileNetV3 paper uses a linear
        # bottleneck here, i.e. no activation after the projection; this
        # implementation keeps one.)
        self.point_conv = nn.Sequential(
            nn.Conv2d(exp_size, out_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_channels),
            activation(inplace=True)
        )
    def forward(self, x):
        # Expand (as in MobileNetV2), then apply the depthwise convolution.
        out = self.conv(x)
        out = self.depth_conv(out)
        # Squeeze-and-Excite.
        if self.SE:
            out = self.squeeze_block(out)
        # Pointwise projection.
        out = self.point_conv(out)
        # Residual connection when stride and channel count allow it.
        if self.use_connect:
            return x + out
        else:
            return out
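
# Example usage, loosely following the second bneck row of the
# MobileNetV3-Large table (a sketch; the shapes below are my own sanity
# check, not part of the original code):
#   block = MobileBlock(16, 24, kernel_size=3, stride=2,
#                       nonLinear="RE", SE=False, exp_size=64)
#   block(torch.randn(1, 16, 112, 112)).shape  # -> torch.Size([1, 24, 56, 56])
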
class SqueezeBlock(nn.Module):
    def __init__(self, exp_size, divide=4):
        super(SqueezeBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Two fully connected layers: reduce the width by `divide`, then restore it.
        self.dense = nn.Sequential(
            nn.Linear(exp_size, exp_size // divide),
            nn.ReLU(inplace=True),
            nn.Linear(exp_size // divide, exp_size),
            h_sigmoid()
        )
    def forward(self, x):
        batch, channels, _, _ = x.size()
        # Global average pool to a per-channel descriptor, then compute gates.
        out = self.avg_pool(x).view(batch, channels)
        out = self.dense(out)
        out = out.view(batch, channels, 1, 1)
        # Rescale the input channel-wise.
        return out * x
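
# Example (illustrative, my own values): SqueezeBlock(72) maps a
# (N, 72, H, W) tensor to the same shape, rescaling each channel by a
# learned weight in [0, 1]:
#   se = SqueezeBlock(72)
#   se(torch.randn(2, 72, 28, 28)).shape  # -> torch.Size([2, 72, 28, 28])
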
def _weights_init(m):
    if isinstance(m, nn.Conv2d):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        torch.nn.init.ones_(m.weight)
        torch.nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        torch.nn.init.normal_(m.weight, 0, 0.01)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)
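
# Typical usage (a sketch; `net` is a placeholder for any nn.Module built
# from these blocks): apply the initializer recursively with Module.apply:
#   net.apply(_weights_init)
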
class h_sigmoid(nn.Module):
    # Hard sigmoid from the MobileNetV3 paper: ReLU6(x + 3) / 6.
    def __init__(self, inplace=True):
        super(h_sigmoid, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        return F.relu6(x + 3., inplace=self.inplace) / 6.
class h_swish(nn.Module):
    # Hard swish: x * h_sigmoid(x) = x * ReLU6(x + 3) / 6.
    def __init__(self, inplace=True):
        super(h_swish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        out = F.relu6(x + 3., inplace=self.inplace) / 6.
        return out * x
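
# A minimal smoke test (my own, not from the original source): build one
# SE-equipped block end to end and check the output shape.
if __name__ == "__main__":
    # Illustrative parameters (my own choice): stride 1 with equal in/out
    # channels, so the residual branch is active.
    block = MobileBlock(40, 40, kernel_size=5, stride=1,
                        nonLinear="HS", SE=True, exp_size=120)
    block.apply(_weights_init)
    x = torch.randn(1, 40, 28, 28)
    print(block(x).shape)  # torch.Size([1, 40, 28, 28])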