U-Net网络

U-Net网络

一、基本架构

请添加图片描述

各个箭头的解释:

  1. conv 3 * 3, ReLU:表示通过一个3 * 3的卷积层,并且该层自动附带一个非线性激活层(ReLu
  2. copy and crop:表示进行裁剪然后再进行拼接(在channel的维度上进行拼接)
  3. max pool 2 * 2:表示通过一个2 * 2的最大池化下采样层,(这一个步骤可以通过一个卷积层进行实现,如果使用最大池化下采样层则会导致丢失pixel(像素)信息)
  4. up-conv 2 * 2:表示一个上采样过程,可以使用转置卷积来实现,也可以使用最邻近插值法来实现,由于使用转置卷积会导致出现很多空洞,因此我们使用最邻近差值法。
  5. conv 1 * 1:表示一个卷积核大小为1 * 1 的卷积层,作用主要是改变维度(即channel的大小)

在实际代码中构建网络时,我们一共为U-Net网络构建了三个模块

  1. 蓝色箭头:我们构建为卷积块,并且使用padding直接进行填充,这样做不会使图片的分辨率发生改变。
  2. 红色箭头:我们构建为下采样块,并且使用的是卷积的操作进行的下采样,因为最大池化层会使得丢失太多的图片信息。
  3. 绿色箭头:我们构建为上采样块,并且与灰色箭头一同实现,上下样的过程中,我们使用的是最邻近插值法。

二、理论分析:

论文解读

对于一个高分辨率的图像,如果直接输入网络则会爆显存,因此需要每次将该图像的一小部分输入网络,并且要求每次输入的一小部分需要与之前输入的部分有重叠,这样做可以很好的利用图像的边缘信息。具体方式如下:

在这里插入图片描述

上图展示了将一个1024 * 1024分辨率的图像进行拆分为N个256 * 256分辨率大小的部分,然后再输入到网络中。

预测边缘图像:
在这里插入图片描述

由于该论文用于医学图像分割领域,作者研究发现,对于细胞与细胞之间的区域分割是有一定困难的,因此,作者提出了Pixel-Weight lose weight的一个方案,也就是在细胞与细胞之间的这些背景区,我们给它施加一个更大的权重,而对于大片的背景区,我们就给它施加一个比较小的权重。

实验分析:

由U-Net网络的架构可以看出,网络的核心是构建了三个模块,即:3 * 3的卷积层构成的卷积块、下采样块、上采样块,由于网络多次使用这三个模块,因此我们可以将这三个模块进行封装。

计算卷积后图像的宽度和高度(公式一):

  • I n p u t : ( N , C i n , H ( i n ) , W ( i n ) ) Input:(N, C_{in}, H_(in), W_(in)) Input:(N,Cin,H(in),W(in))

  • O u t P u t : ( N , C ( o u t ) , H ( o u t ) , W ( o u t ) ) OutPut:(N, C_(out), H_(out), W_(out)) OutPut:(N,C(out),H(out),W(out))

    H ( o u t ) = [ H ( i n ) + 2 × p a d d i n g [ 0 ] − d i l a t i o n [ 0 ] × ( k e r n e l s i z e [ 0 ] − 1 ) − 1 s t r i d e [ 0 ] + 1 ] H_(out) = [\frac{H_(in) + 2 \times padding[0] - dilation[0] \times (kernel_{size[0]} - 1) - 1}{stride[0]} + 1] H(out)=[stride[0]H(in)+2×padding[0]dilation[0]×(kernelsize[0]1)1+1]
    ·

    W ( o u t ) = [ W ( i n ) + 2 × p a d d i n g [ 1 ] − d i l a t i o n [ 1 ] × ( k e r n e l s i z e [ 1 ] − 1 ) s t r i d e [ 0 ] − 1 ] W_(out) = [\frac{W_(in) + 2 \times padding[1] - dilation[1] \times (kernel_{size[1]} - 1)}{stride[0]} - 1] W(out)=[stride[0]W(in)+2×padding[1]dilation[1]×(kernelsize[1]1)1]

    参数解释:padding是填充的大小,dilation是空洞卷积的大小(即卷积核各个单元之间有多少个间隔),kernel_size是卷积核的大小。

    空洞卷积:
    在这里插入图片描述

1. 卷积块
  • 首先定义一个Convolution(卷积层),卷积核大小为3 * 3(即:kernel_size = 3),分析U-Net架构图(输入:[1, 572, 572] ==> 输出:[64, 570, 570])可以得到,channel的维度由1 上升到了64,所以定义64个卷积核,由于后面的copy and crop拼接的时候还需要进行裁剪,会导致很麻烦,因此现在的主流的方式是将卷积层加上一个padding,即通过卷积层后不会改变图像的高和宽,并且会在卷积核与 ReLU 之间加上一个BN(Batch normalization),由于没有使用空洞卷积,默认dilatation = 1,由**(公式一)**可以得到stride = 1padding = 1,这样保证了卷积后图像的高度和宽度不会改变。

  • 然后再添加一个Batch normalization层进行归一化处理,这样的好处是加快收敛

  • 再添加一个Dropout层,Dropout 是一种正则化技术,通过在训练过程中随机丢弃一部分神经元的输出,可以减少过拟合并提升模型的泛化能力。

  • 最后添加一个LeakReLU层

    LeakyReLU 函数在处理负值时不像 ReLU 那样完全将其置零,而是允许一小部分负输入信息的线性泄漏。这有助于缓解ReLU 死亡问题,即神经元可能陷入零激活状态,使得模型难以学习。

    数学上,LeakyReLU 函数的定义如下:

    f(x) = max(ax, x)

    其中:

    • x 表示函数的输入,

    • a 是一个小常数(通常是一个小的正值,如 0.01),它代表函数负值部分的斜率。

    由于U-Net网络每次需要添加两次卷积层,因此需要将上述定义的卷积层再次重复一次

卷积块代码

class Conv_Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Conv_Block, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=False),   # 填充模式padding_mode='reflect'表示边界向内复制, 第二个参数out_channel表示卷积核的数量
            nn.BatchNorm2d(out_channel),  # 归一化处理,参数为特征图的通道数
            nn.Dropout(0.3),   # 这条语句的作用是创建一个丢弃比例为0.3的 Dropout 层,也就是30%的输入将被随机置为0。。Dropout 是一种正则化技术,通过在训练过程中随机丢弃一部分神经元的输出,可以减少过拟合并提升模型的泛化能力
            nn.LeakyReLU(),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout(0.3),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.layer(x)
2. 下采样块

由于最大池化丢弃了太多的特征,因此我们使用一个3 * 3 的卷积来进行最大池化

  • 首先定义一个3 * 3 的卷积核,并且通过U-Net网络的结构图(输入:[64, 568, 568]==> 输出:[64, 284, 284])可以知到,相当于将图像的宽度和高度进行了减半,因此我们在卷积核中设置padding = 1 stride = 2
  • 然后添加一个Batch Normalization层
  • 最后添加一个LeakReLU层

下采样块代码

class DownSample(nn.Module):
    def __init__(self, channel):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(
            # 最大池化时,通道数量不变
            nn.Conv2d(channel, channel, kernel_size=3, stride=2, padding=1, padding_mode='reflect', bias=False),
            # 'reflect' 模式意味着在边缘周围反射输入图像的像素值。这种模式可以减少边缘效应,并且有助于保持特征图的边界信息。
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.layer(x)
3. 上采样块

由于使用转置卷积会导致出现很多空洞,因此我们使用最邻近差值法

  • 首先使用最邻近插值法对输入的特征图进行处理(Pytorch中的方法是:nn.Functional()函数)。
  • 然后使用一个1 * 1的卷积将图像进行升维。
  • 最后将与该层对应的层在channel维度上进行拼接(Pytorch中的方法是:torch.cat())。

上采样块代码

class UpSample(nn.Module):
    def __init__(self, channel):
        super(UpSample, self).__init__()
        self.layer = nn.Conv2d(channel, channel // 2, kernel_size=1, stride=1)

    def forward(self, x, feature_map):
        up = F.interpolate(x, scale_factor=2, mode='nearest')  # 参数解释:scale_factor :变为原来的2倍, mode :使用什么方式,这里为使用最邻近插值法
        out = self.layer(up)
        # 实现拼接
        return torch.cat((out, feature_map), dim=1)  # [N, C, H, W] 在通道的维度进行拼接
U-Net的整体定义
  • 首先定义一个卷积层,后面连接一个下采样层,重复4次。
  • 然后添加一个卷积层。
  • 再添加一个上采样层,后面连接一个卷积层,重复4次。
  • 最后添加一个3 * 3的卷积层,将维度映射为(RGB)3个channel

U-Net整体代码

class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        self.c1 = Conv_Block(3, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)
        self.d2 = DownSample(128)
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256, 512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512, 1024)
        # 开始进行上采样
        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        # 进行输出
        self.out = nn.Conv2d(64, 3, (3, 3), 1, 1)
        self.Th = nn.Sigmoid()  # 由于我们只需要直到图像的蒙版,只需要知到这个像素是黑的还是白的,因此这是一个二分类问题

    def forward(self, x):
        R1 = self.c1(x)
        R2 = self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))
        # 进行上采样
        O1 = self.c6(self.u1(R5, R4))  # 进行拼接
        O2 = self.c7(self.u2(O1, R3))
        O3 = self.c8(self.u3(O2, R2))
        O4 = self.c9(self.u4(O3, R1))

        return self.Th(self.out(O4))

三、代码实现:

U_Net_model.py

import torch
from torch import nn
from torch.nn import functional as F

# 构建卷积块
class Conv_Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Conv_Block, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=(3, 3), stride=1, padding=1, padding_mode='reflect', bias=False),   # 填充模式padding_mode='reflect'表示边界向内复制, 第二个参数out_channel表示卷积核的数量
            nn.BatchNorm2d(out_channel),  # 归一化处理,参数为特征图的通道数
            nn.Dropout(0.3),   # 这条语句的作用是创建一个丢弃比例为0.3的 Dropout 层,也就是30%的输入将被随机置为0。。Dropout 是一种正则化技术,通过在训练过程中随机丢弃一部分神经元的输出,可以减少过拟合并提升模型的泛化能力
            nn.LeakyReLU(),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(out_channel),
            nn.Dropout(0.3),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.layer(x)


# 最大池化下采样(由于最大池化丢弃了太多的特征,因此我们使用一个3 * 3 的卷积来进行最大池化)
class DownSample(nn.Module):
    def __init__(self, channel):
        super(DownSample, self).__init__()
        self.layer = nn.Sequential(
            # 最大池化时,通道数量不变
            nn.Conv2d(channel, channel, kernel_size=3, stride=2, padding=1, padding_mode='reflect', bias=False),
            nn.BatchNorm2d(channel),
            nn.LeakyReLU()
        )

    def forward(self, x):
        return self.layer(x)


# 下采样(由于使用转置卷积会导致出现很多空洞,因此我们使用最邻近差值法)
class UpSample(nn.Module):
    def __init__(self, channel):
        super(UpSample, self).__init__()
        self.layer = nn.Conv2d(channel, channel // 2, kernel_size=1, stride=1)

    def forward(self, x, feature_map):
        up = F.interpolate(x, scale_factor=2, mode='nearest')  # 参数解释:scale_factor :变为原来的2倍, mode :使用什么方式,这里为使用最邻近插值法
        out = self.layer(up)
        # 实现拼接
        return torch.cat((out, feature_map), dim=1)  # [N, C, H, W] 在通道的维度进行拼接


class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        self.c1 = Conv_Block(3, 64)
        self.d1 = DownSample(64)
        self.c2 = Conv_Block(64, 128)
        self.d2 = DownSample(128)
        self.c3 = Conv_Block(128, 256)
        self.d3 = DownSample(256)
        self.c4 = Conv_Block(256, 512)
        self.d4 = DownSample(512)
        self.c5 = Conv_Block(512, 1024)
        # 开始进行上采样
        self.u1 = UpSample(1024)
        self.c6 = Conv_Block(1024, 512)
        self.u2 = UpSample(512)
        self.c7 = Conv_Block(512, 256)
        self.u3 = UpSample(256)
        self.c8 = Conv_Block(256, 128)
        self.u4 = UpSample(128)
        self.c9 = Conv_Block(128, 64)
        # 进行输出
        self.out = nn.Conv2d(64, 3, 3, 1, 1)
        self.Th = nn.Sigmoid()  # 由于我们只需要直到图像的蒙版,只需要知到这个像素是黑的还是白的,因此这是一个二分类问题

    def forward(self, x):
        R1 = self.c1(x)
        R2 = self.c2(self.d1(R1))
        R3 = self.c3(self.d2(R2))
        R4 = self.c4(self.d3(R3))
        R5 = self.c5(self.d4(R4))
        # 进行上采样
        O1 = self.c6(self.u1(R5, R4))  # 进行拼接
        O2 = self.c7(self.u2(O1, R3))
        O3 = self.c8(self.u3(O2, R2))
        O4 = self.c9(self.u4(O3, R1))

        return self.Th(self.out(O4))


if __name__ == '__main__':
    '''
    定义网络的结构使用的代码,整个U-Net网络
    '''
    x = torch.randn(2, 3, 572, 572)
    net = Unet()
    print(net(x).shape)

utils.py

utils.py文件用于对输入的图片的shape进行处理

由于直接对图像进行缩放会导致图像进行变形,这就导致图像的特征发生了变化,为了保证图像特征的完整性,我们使用蒙版的方法进行设定输入图像的大小,具体方式如下:

  • 首先找到图像中最大的边。
  • 然后利用此边设定一个值为0的方形蒙版。
  • 将图片粘贴到该蒙版上。
  • 对结合后的蒙版进行等比缩放得到需要的图片大小。

在这里插入图片描述

from PIL import Image


# 对图片进行缩放
def keep_image_size_open(path, size=(256, 256)):
    img = Image.open(path)
    # img.size返回的是一个元组,temp获取的是每一张图片的最大长度
    temp = max(img.size)
    # Image.new(mode, size, color),用于创建一个新的图像。color表示图像的初始颜色
    mask = Image.new('RGB', (temp, temp), (0, 0, 0))  
    '''
    mask.paste(im, box, mask=None) 用于将一个图像粘贴到另一个图像上,并可以指定粘贴的位置以及透明度,参数解释:im表示要粘贴的图像,box定义了粘贴位置和大小的矩形框(0, 0)表示从左上角进行粘贴
    '''
    mask.paste(img, (0, 0))  
    mask = mask.resize((size))  # 调整大小
    return mask

if __name__ == '__main__':
    keep_image_size_open("./data/JPEGImages/000033.jpg").show()

My_DataSet.py

import os

from torch.utils.data import Dataset
from utils import *
from torchvision import transforms

# 将数据转换为Tenso类型
transform = transforms.Compose([
    transforms.ToTensor()
])


# 定义数据集(图像分割数据集)
class MyDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.name = os.listdir(os.path.join(path, "SegmentationClass"))

    def __len__(self):
        return len(self.name)

    def __getitem__(self, index):
        segment_name = self.name[index]   # 格式:xxx.png
        # 拼接得到蒙版的地址
        segment_path = os.path.join(self.path, 'SegmentationClass', segment_name)
        # 拼接得到原图的地址
        image_paht = os.path.join(self.path, 'JPEGImages', segment_name.replace('png', 'jpg'))
        # 将蒙版与原图进行读取进来
        segment_image = keep_image_size_open(segment_path)
        image = keep_image_size_open(image_paht)
        return transform(image), transform(segment_image)


if __name__ == '__main__':
    path = './data'
    data = MyDataset(path)
    print(data[0][0].shape)
    print(data[0][1].shape)

train.py

from torch import nn, optim
import torch
from torch.utils.data import DataLoader
from My_DataSet import *
from net import *
from torchvision.utils import save_image


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weight_path = 'params/unet.pth'
data_pth = './Data/VOCdevkit/VOC2007'
save_path = 'train_image'

def main():
    '''
    训练网络使用的代码
    '''
    data_loader = DataLoader(MyDataset(data_pth), batch_size=2, shuffle=True)
    net = Unet().to(device)
    # 读取之前训练的权重
    if os.path.exists(weight_path):
        net.load_state_dict(torch.load(weight_path))
        print("SUCCESSFUL LOAD WEIGHT!")
    else:
        print("NOT SUCCESSFUL LOAD WEIGHT")
	
    # 设置优化器以及损失函数
    opt = optim.Adam(net.parameters())
    loss_fn = nn.BCELoss()

    epochs = 1000
    for epoch in range(epochs):
        for i, (image, segment_image) in enumerate(data_loader):
            image, segment_image = image.to(device), segment_image.to(device)

            out_image = net(image)
            train_loss = loss_fn(out_image, segment_image)

            opt.zero_grad()
            train_loss.backward()
            opt.step()
			
            # 每训练5个图片输出一次损失
            if i % 5 == 0:
                print(f'{epoch}-{i}-train_loss---->>{train_loss.item()}')
			
            # 每训练50个图片更新一次权重
            if i % 50 == 0:
                torch.save(net.state_dict(), weight_path)
                
            # 每训练100个图片
            if i % 100 == 0:
                _image = image[0]
                _segment_image = segment_image[0]
                _out_image = out_image[0]

                img = torch.stack([_image, _segment_image, _out_image], dim=0)
                save_image(img, f'{save_path}/{i}.png')

                

if __name__ == '__main__':
    main()

predict.py

import os.path
import torch
from utils import *
from net import *
from My_DataSet import *
from torchvision.utils import save_image


# 实例化U-Net网络
net = Unet().cuda()

# 读取训练的权重
weights = 'params/unet.pth'
if os.path.exists(weights):
    net.load_state_dict(torch.load(weights))
    print('SUCCESSFULLY')
else:
    print('NO LOADING')

# 输入需要预测的图片的路径
_input = input('please input JPEGImages path:')

# 对图片的格式进行调整
img = keep_image_size_open(_input)
# 指定调用的硬件资源
img_data = transform(img).cuda()
# 在第0维增加一维,因为训练的时候有batch维度,这里需要添加一维
img_data = torch.unsqueeze(img_data, dim=0)
# 得到网络的输出
out = net(img_data)
# 对预测的得到的蒙版进行保存
save_image(out, 'result/result.jpg')
print(out.shape)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值