本文采用Efficientnet_b3作为主干网络替换unet的下采样部分,使网络提取特征更强大
将红色框的信息替换修改
搭建上采样block
import torch
import torch.nn as nn
import torchvision.models as models
from torchsummary import summary
#基本的block
class DecoderBlock(nn.Module):
def __init__(self,
in_channels=512,
n_filters=256,
kernel_size=3,
is_deconv=False,
):
super().__init__()
if kernel_size == 3:
conv_padding = 1
elif kernel_size == 1:
conv_padding = 0
# B, C, H, W -> B, C/4, H, W
self.conv1 = nn.Conv2d(in_channels,
in_channels // 4,
kernel_size,
padding=1,bias=False)
self.norm1 = nn.BatchNorm2d(in_channels // 4)
self.relu1 = nn.ReLU(inplace=True)
# B, C/4, H, W -> B, C/4, H, W
if is_deconv == True: