手搓Unet部分的一些尝试

在写Unet网络的部分之前需要知道VGG16网络的架构,因为Unet借鉴了VGG16前面部分的结构。

Unet借鉴了除了全连接层外,即最后一个最大池化层之前的所有结构,以下为Unet网络结构。

然后根据Unet的网络结构图,大致可以写得代码如下。为了方便,卷积块和反卷积块的结构都写得比较简单。

def conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_channels),       #定义了一个二维批量归一化层,标准化每个小批量的特征图
        nn.ReLU(inplace=True)               #就地操作,节约内存
    )

def up_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
        nn.ReLU(inplace=True)
    )

from torchvision.models import vgg16_bn
class UNet(nn.Module):
    def __init__(self, pretrained=True, out_channels=12):
        super().__init__()

        #编码器,将 VGG16 模型的前34层分为5个卷积块,分别提取低级到高级特征
        self.encoder = vgg16_bn(pretrained=pretrained).features
        self.block1 = nn.Sequential(*self.encoder[:6])              #包含5个层Conv1_1, ReLU,Conv1_2, ReLU,MaxPool1,输出通道数64
        self.block2 = nn.Sequential(*self.encoder[6:13])            #输出通道数128
        self.block3 = nn.Sequential(*self.encoder[13:20])           #包含7个层Conv3_1, ReLU,Conv3_2, ReLU,Conv3_3, ReLU,MaxPool3,256
        self.block4 = nn.Sequential(*self.encoder[20:27])           #512
        self.block5 = nn.Sequential(*self.encoder[27:34])           #512

        #连接层
        self.bottleneck = nn.Sequential(*self.encoder[34:])
        self.conv_bottleneck = conv(512, 1024)                      #将通道数从512增加到1024

        #解码器,恢复原图尺寸
        self.up_conv6 = up_conv(1024, 512)                          #上采样到512
        self.conv6 = conv(512 + 512, 512)                           #上采样后的特征图与block5特征图拼接,并卷积处理
        self.up_conv7 = up_conv(512, 256)
        self.conv7 = conv(256 + 512, 256)
        self.up_conv8 = up_conv(256, 128)
        self.conv8 = conv(128 + 256, 128)
        self.up_conv9 = up_conv(128, 64)
        self.conv9 = conv(64 + 128, 64)
        self.up_conv10 = up_conv(64, 32)
        self.conv10 = conv(32 + 64, 32)
        self.conv11 = nn.Conv2d(32, out_channels, kernel_size=1)    #输出层

反向传播部分代码。

    def forward(self, x):
        block1 = self.block1(x)
        block2 = self.block2(block1)
        block3 = self.block3(block2)
        block4 = self.block4(block3)
        block5 = self.block5(block4)

        bottleneck = self.bottleneck(block5)
        x = self.conv_bottleneck(bottleneck)

        x = self.up_conv6(x)
        x = torch.cat([x, block5], dim=1)       #将上采样后的特征图x与编码器相应层的特征图 block5 进行拼接(跳跃连接),保留高分辨率的细节信息
        x = self.conv6(x)

        x = self.up_conv7(x)
        x = torch.cat([x, block4], dim=1)
        x = self.conv7(x)

        x = self.up_conv8(x)
        x = torch.cat([x, block3], dim=1)
        x = self.conv8(x)

        x = self.up_conv9(x)
        x = torch.cat([x, block2], dim=1)
        x = self.conv9(x)

        x = self.up_conv10(x)
        x = torch.cat([x, block1], dim=1)
        x = self.conv10(x)

        x = self.conv11(x)
        return x

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值