# This model replaces the U-Net downsampling (encoder) path with an
# EfficientNet-B3 backbone for stronger feature extraction.
# Only the encoder portion (the red-boxed region in the reference figure) is
# swapped out; custom upsampling blocks are built below for the decoder.
import torch
import torch.nn as nn
import torchvision.models as models
from torchsummary import summary
#基本的block
class DecoderBlock(nn.Module):
    """Basic decoder block: bottleneck -> 2x upsample -> expand.

    Pipeline (N = batch, C = in_channels, H/W = spatial dims):
        1) 1 conv:  (N, C, H, W)   -> (N, C//4, H, W)     channel reduction
        2) upsample: (N, C//4, H, W) -> (N, C//4, 2H, 2W)  transposed conv or bilinear
        3) 1 conv:  (N, C//4, 2H, 2W) -> (N, n_filters, 2H, 2W)

    Each stage is followed by BatchNorm + ReLU.

    Args:
        in_channels: number of input feature channels.
        n_filters: number of output feature channels.
        kernel_size: conv kernel size; must be 1 or 3 (padding is chosen
            to preserve spatial size).
        is_deconv: if True use ConvTranspose2d for the 2x upsample,
            otherwise bilinear nn.Upsample.
    """
    def __init__(self,
                 in_channels=512,
                 n_filters=256,
                 kernel_size=3,
                 is_deconv=False,
                 ):
        super().__init__()
        # Padding that keeps spatial size for the supported kernel sizes.
        if kernel_size == 3:
            conv_padding = 1
        elif kernel_size == 1:
            conv_padding = 0
        else:
            # BUG FIX: the original silently left conv_padding undefined for
            # other kernel sizes (NameError later); fail fast instead.
            raise ValueError(
                "kernel_size must be 1 or 3, got %r" % (kernel_size,))
        # B, C, H, W -> B, C/4, H, W
        # BUG FIX: the original hard-coded padding=1 here, which grew the
        # feature map by 2 pixels when kernel_size == 1; use conv_padding so
        # spatial size is preserved for every supported kernel size.
        self.conv1 = nn.Conv2d(in_channels,
                               in_channels // 4,
                               kernel_size,
                               padding=conv_padding, bias=False)
        self.norm1 = nn.BatchNorm2d(in_channels // 4)
        self.relu1 = nn.ReLU(inplace=True)
        # B, C/4, H, W -> B, C/4, 2H, 2W
        if is_deconv:
            # stride-2 transposed conv; output_padding makes the output
            # exactly double the input size for kernel_size == 3.
            self.deconv2 = nn.ConvTranspose2d(in_channels // 4,
                                              in_channels // 4,
                                              3,
                                              stride=2,
                                              padding=1,
                                              output_padding=conv_padding,
                                              bias=False)
        else:
            up_kwargs = {'mode': 'bilinear', 'align_corners': True}
            self.deconv2 = nn.Upsample(scale_factor=2, **up_kwargs)
        self.norm2 = nn.BatchNorm2d(in_channels // 4)
        self.relu2 = nn.ReLU(inplace=True)
        # B, C/4, 2H, 2W -> B, n_filters, 2H, 2W
        self.conv3 = nn.Conv2d(in_channels // 4,
                               n_filters,
                               kernel_size,
                               padding=conv_padding, bias=False)
        self.norm3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run reduce -> upsample -> expand; returns (N, n_filters, 2H, 2W)."""
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.deconv2(x)
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.relu3(x)
        return x
# The downsampling (encoder) path uses EfficientNet-B3 as the backbone network.
class Eb3net(nn.Module):
    """U-Net-style segmentation network with an EfficientNet-B3 encoder.

    The torchvision EfficientNet-B3 feature stages provide the encoder
    (stem + 4 stages with skip connections); a stack of DecoderBlocks
    upsamples back to the input resolution and a final conv head produces
    per-pixel class logits.

    Args:
        num_classes: number of output segmentation classes.
        num_channels: number of input image channels (3 reuses the B3 stem
            conv; any other value gets a fresh stem conv).
        is_deconv: passed to every DecoderBlock (transposed conv vs bilinear).
        decoder_kernel_size: conv kernel size inside every DecoderBlock.

    Output of forward(): (N, num_classes, H, W) logits at input resolution.
    """
    def __init__(self,
                 num_classes,
                 num_channels=3,
                 is_deconv=False,
                 decoder_kernel_size=3,
                 ):
        super().__init__()
        # ImageNet normalization stats, exposed for callers to normalize input.
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        # Channel counts of the skip tensors: [e1, e2, e3, e4(bottleneck), stem]
        filters = [32, 48, 136, 1536, 40]
        # NOTE(review): randomly initialized backbone — no pretrained weights
        # are loaded; pass weights=... to efficientnet_b3 if they are wanted.
        efficientnet_b3 = models.efficientnet_b3()
        self.base_size = 512
        self.crop_size = 512
        if num_channels == 3:
            # Reuse the pretrained-compatible B3 stem conv (out = 40 channels).
            self.firstconv = efficientnet_b3.features[0][0]
        else:
            # Fresh stem conv for non-RGB input; stride 2 and padding keep the
            # same 40-channel, half-resolution output as the B3 stem.
            # NOTE(review): B3's own stem uses a 3x3 kernel, not 7x7 — output
            # shape is identical either way; confirm which is intended.
            self.firstconv = nn.Conv2d(num_channels, 40, kernel_size=(7, 7),
                                       stride=(2, 2), padding=(3, 3), bias=False)
        self.firstbn = efficientnet_b3.features[0][1]
        self.firstsilu = efficientnet_b3.features[0][2]  # H/2, 40 ch
        self.encoder1 = efficientnet_b3.features[1:3]    # H/4,  32 ch
        self.encoder2 = efficientnet_b3.features[3:4]    # H/8,  48 ch
        self.encoder3 = efficientnet_b3.features[4:6]    # H/16, 136 ch
        self.encoder4 = efficientnet_b3.features[6:]     # H/32, 1536 ch
        # Decoder: each block doubles spatial size; inputs are the previous
        # decoder output concatenated with the matching encoder skip tensor.
        self.center = DecoderBlock(in_channels=filters[3],
                                   n_filters=filters[3],
                                   kernel_size=decoder_kernel_size,
                                   is_deconv=is_deconv)
        self.decoder4 = DecoderBlock(in_channels=filters[3] + filters[2],
                                     n_filters=filters[2],
                                     kernel_size=decoder_kernel_size,
                                     is_deconv=is_deconv)
        self.decoder3 = DecoderBlock(in_channels=filters[2] + filters[1],
                                     n_filters=filters[1],
                                     kernel_size=decoder_kernel_size,
                                     is_deconv=is_deconv)
        self.decoder2 = DecoderBlock(in_channels=filters[1] + filters[0],
                                     n_filters=filters[0],
                                     kernel_size=decoder_kernel_size,
                                     is_deconv=is_deconv)
        self.decoder1 = DecoderBlock(in_channels=filters[0] + filters[4],
                                     n_filters=filters[0],
                                     kernel_size=decoder_kernel_size,
                                     is_deconv=is_deconv)
        # Classification head: 3x3 conv + BN/ReLU, dropout, 1x1 conv to logits.
        self.finalconv = nn.Sequential(
            nn.Conv2d(filters[0], 32, 3, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(0.1, False),
            nn.Conv2d(32, num_classes, 1))

    def require_encoder_grad(self, requires_grad):
        """Freeze (False) or unfreeze (True) every encoder parameter."""
        blocks = [self.firstconv,
                  self.encoder1,
                  self.encoder2,
                  self.encoder3,
                  self.encoder4]
        for block in blocks:
            for p in block.parameters():
                p.requires_grad = requires_grad

    def forward(self, x):
        # stem: H -> H/2
        x = self.firstconv(x)
        x = self.firstbn(x)
        x = self.firstsilu(x)
        # Encoder
        e1 = self.encoder1(x)    # H/4
        e2 = self.encoder2(e1)   # H/8
        e3 = self.encoder3(e2)   # H/16
        e4 = self.encoder4(e3)   # H/32
        # Decoder with skip connections
        center = self.center(e4)                         # H/16
        d4 = self.decoder4(torch.cat([center, e3], 1))   # H/8
        d3 = self.decoder3(torch.cat([d4, e2], 1))       # H/4
        d2 = self.decoder2(torch.cat([d3, e1], 1))       # H/2
        d1 = self.decoder1(torch.cat([d2, x], 1))        # H
        f = self.finalconv(d1)
        # BUG FIX: the original returned the stem activation `x`, silently
        # discarding the entire decoder and the classification head.
        # Return the segmentation logits instead.
        return f
if __name__ == "__main__":
    # Smoke test: print a layer-by-layer summary of a 1-channel, 2-class model.
    # Guarded so that importing this module does not build/run the model.
    summary(Eb3net(num_channels=1, num_classes=2),
            input_size=[(1, 256, 256)], batch_size=2, device="cpu")