Unet,Pix2pix的解析和从零实现

本文介绍了Unet模型的结构和工作原理,该模型结合了卷积和反卷积操作,用于图像分割任务。Unet的特点在于其下采样和上采样的对称结构,以及跳跃连接,以保留不同层次的特征信息。文章通过一个简单的PyTorch实现示例,展示了如何构建一个Unet网络,并解释了反卷积在处理图像尺寸变化时避免棋盘效应的方法。
摘要由CSDN通过智能技术生成

简要介绍
  • Unet名称是由于它长得像U, 前半部分是下采样, feature map越来越多但是越来越小, 然后后半部分是上采样特征图减少但是逐渐变大。同时在下采样和上采样之间有跳连接, 下图是论文中的图。
  • Unet的好处我感觉是:网络层越深得到的特征图,有着更大的视野域,浅层卷积关注纹理特征,深层网络关注本质的那种特征,所以深层浅层特征都是有格子的意义的;另外一点是通过反卷积得到的更大的尺寸的特征图的边缘,是缺少信息的,毕竟每一次下采样提炼特征的同时,也必然会损失一些边缘特征,而失去的特征并不能从上采样中找回,因此通过特征的拼接,来实现边缘特征的一个找回。
    在这里插入图片描述

可以看到输入是一个572的灰度图, 然后最后输出的是388的二分类, 下面分步阐述:

  1. 33卷积两次到64个568的feature map, pooling变为一半, 然后两次33卷积变为128个280的feature map
  2. 重复33卷积的下采样, 然后到1024个2828的feature map为止, 也就是黄线的部分
    在这里插入图片描述
  3. 其次是上采样, 注意到这里上采样一次后是特征图数量减小一半, 大小增加一倍, 从左边跳连接过来的尺寸并不匹配, 比如从1024个28的特征图上采样到512个56的, 左边是64, 这里需要裁一下64的, 应该是pooling
反卷积
  • 这是原来的卷积
    在这里插入图片描述
  • 下面是反卷积
  • 当我们要用到深度学习来生成图像的时候,往往是基于一个低分辨率且具有高维的特征图。我们通常使用反卷积操作来完成此操作,但不幸的是,当卷积核大小不能被步长整除的时候,会出现棋盘现象, 正确的方法是:调整图像大小(使用最近邻插值或双线性插值),然后执行卷积操作。这似乎是一种自然的方法,大致相似的方法在图像超分辨率方面表现良好
    在这里插入图片描述
简易例程

import torch
import torch.nn as nn
import torch.nn.functional as F

class double_conv2d_bn(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=3,strides=1,padding=1):
        super(double_conv2d_bn,self).__init__()
        self.conv1 = nn.Conv2d(in_channels,out_channels,
                               kernel_size=kernel_size,
                              stride = strides,padding=padding,bias=True)
        self.conv2 = nn.Conv2d(out_channels,out_channels,
                              kernel_size = kernel_size,
                              stride = strides,padding=padding,bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
    
    def forward(self,x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out
    
class deconv2d_bn(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=2,strides=2):
        super(deconv2d_bn,self).__init__()
        self.conv1 = nn.ConvTranspose2d(in_channels,out_channels,
                                        kernel_size = kernel_size,
                                       stride = strides,bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        
    def forward(self,x):
        out = F.relu(self.bn1(self.conv1(x)))
        return out
    
class Unet(nn.Module):
    def __init__(self):
        super(Unet,self).__init__()
        self.layer1_conv = double_conv2d_bn(1,8)
        self.layer2_conv = double_conv2d_bn(8,16)
        self.layer3_conv = double_conv2d_bn(16,32)
        self.layer4_conv = double_conv2d_bn(32,64)
        self.layer5_conv = double_conv2d_bn(64,128)
        self.layer6_conv = double_conv2d_bn(128,64)
        self.layer7_conv = double_conv2d_bn(64,32)
        self.layer8_conv = double_conv2d_bn(32,16)
        self.layer9_conv = double_conv2d_bn(16,8)
        self.layer10_conv = nn.Conv2d(8,1,kernel_size=3,
                                     stride=1,padding=1,bias=True)
        
        self.deconv1 = deconv2d_bn(128,64)
        self.deconv2 = deconv2d_bn(64,32)
        self.deconv3 = deconv2d_bn(32,16)
        self.deconv4 = deconv2d_bn(16,8)
        
        self.sigmoid = nn.Sigmoid()
        
    def forward(self,x):
        conv1 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1,2)
        
        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2,2)
        
        conv3 = self.layer3_conv(pool2)
        pool3 = F.max_pool2d(conv3,2)
        
        conv4 = self.layer4_conv(pool3)
        pool4 = F.max_pool2d(conv4,2)
        
        conv5 = self.layer5_conv(pool4)
        
        convt1 = self.deconv1(conv5)
        concat1 = torch.cat([convt1,conv4],dim=1)
        conv6 = self.layer6_conv(concat1)
        
        convt2 = self.deconv2(conv6)
        concat2 = torch.cat([convt2,conv3],dim=1)
        conv7 = self.layer7_conv(concat2)
        
        convt3 = self.deconv3(conv7)
        concat3 = torch.cat([convt3,conv2],dim=1)
        conv8 = self.layer8_conv(concat3)
        
        convt4 = self.deconv4(conv8)
        concat4 = torch.cat([convt4,conv1],dim=1)
        conv9 = self.layer9_conv(concat4)
        outp = self.layer10_conv(conv9)
        outp = self.sigmoid(outp)
        return outp
    

model = Unet()
inp = torch.rand(10,1,224,224)
outp = model(inp)
print(outp.shape)
==> torch.Size([10, 1, 224, 224])

CycleGAN 和 pix2pix 是两种常用的图像到图像(Image-Image)转换模型,尤其在无监督学习中非常流行。在PyTorch库中实现这两种模型通常涉及到深度学习框架和一些高级的图像处理技术。 1. **Pix2Pix**[^4]: Pix2Pix使用条件生成对抗网络(Conditional GANs),它结合了卷积神经网络(CNN)和U-Net架构。在PyTorch中,可以这样实现: ```python import torch.nn as nn from unet import UNet # 假设你有一个名为UNet的U-Net实现 class Pix2PixModel(nn.Module): def __init__(self, input_channels, output_channels): super(Pix2PixModel, self).__init__() self.netG = UNet(input_channels, output_channels) self.netD = Discriminator(output_channels) # 假设Discriminator是一个预定义的模型 def forward(self, x): fake_B = self.netG(x) return fake_B model = Pix2PixModel(input_channels=3, output_channels=3) # 输入和输出都是RGB图像 ``` 2. **CycleGAN**[^5]: CycleGAN则是无条件的,它通过一个翻译网络(Generator)和一个反向翻译网络(Cycle-GAN中的Discriminators)来训练。PyTorch实现可能包括: ```python class Generator(nn.Module): # ...定义网络结构... class Discriminator(nn.Module): # ...定义网络结构... generator_A2B = Generator() generator_B2A = Generator() adversarial_loss = nn.BCELoss() cycle_loss = nn.L1Loss() def train_step(A, B): # ...执行一个训练步骤,包括生成器和判别器的更新... ``` 训练过程中,CycleGAN还会包含一个额外的损失项来确保生成的图像在经过双向转换后仍能保持原始输入的相似性。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

live_for_myself

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值