卷积块 CBM(Conv + BatchNorm + Mish)
Conv:提取特征
BN:1. 缓解梯度消失/梯度爆炸 2. 带有一定的正则化效果,可缓解过拟合 3. 加速收敛
Mish:平滑、非单调的激活函数 x·tanh(softplus(x));与 ReLU 相比,负半轴仍保留少量梯度,有助于缓解梯度消失
class Mish(nn.Module):
    """Mish activation, applied elementwise: ``x * tanh(softplus(x))``.

    The module has no learnable parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # softplus(x) = log(1 + exp(x)); F.softplus is used instead of the
        # literal formula so PyTorch handles large inputs for us.
        gate = torch.tanh(F.softplus(x))
        return x * gate
class CBM(nn.Module):
    """Conv2d -> BatchNorm2d -> Mish convolution block.

    Args:
        in_channels: number of input channels to the convolution.
        out_channels: number of output channels (also the BatchNorm width).
        kernel_size: integer size of the square convolution kernel.
        stride: convolution stride (default 1).

    Padding is ``kernel_size // 2``. For odd kernels with stride 1, the
    spatial size is therefore preserved. The convolution has no bias
    because the BatchNorm that follows applies its own learnable shift.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = Mish()

    def forward(self, x):
        return self.activation(self.bn(self.conv(x)))
测试
# Smoke test: run one stride-2 CBM block on a random image-shaped tensor.
# PyTorch conv layers expect NCHW layout: (batch, channels, height, width).
rgb = torch.randn(1, 3, 32, 32)
print(rgb.shape)  # torch.Size([1, 3, 32, 32])
# 3 -> 1 channels, 3x3 kernel, stride 2, padding 3 // 2 = 1:
# out = floor((32 + 2*1 - 3) / 2) + 1 = 16, so the spatial size halves.
test_downsample_conv = CBM(3, 1, 3, stride=2)
x = test_downsample_conv(rgb)
print(x.shape)  # expected: torch.Size([1, 1, 16, 16])