U2Net论文解读及代码测试

论文名称: U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection
论文地址: https://arxiv.org/pdf/2005.09007.pdf
论文作者:Xuebin Qin, Zichen Zhang, Chenyang Huang, Masood Dehghan, Osmar R. Zaiane and Martin Jagersand University of Alberta, Canada
Github地址U-2-Net

一、论文解读

1. 前言

设计了一个简单而强大的深度网络架构U2-Net,用于显著目标检测(SOD)。我们的U2-Net的体系结构是一个两层嵌套的U结构。

该设计有以下两点优势:
(1)它能够捕捉更多的上下文信息,因为提出了RSU(ReSidual U-blocks)结构,融合了不同尺度的感受野的特征;
(2)它增加了整个架构的深度但并没有显著增加计算成本,因为在这些RSU块中使用了池化操作。

这种架构使我们能够从头开始训练深度网络,而无需使用图像分类任务中的backbone

2. 分析

显著目标检测(Salient Object Detection, SOD)的目的是分割出图像中最具吸引力的目标。它在视觉跟踪、图像分割等领域有着广泛的应用。随着深度卷积神经网络,尤其是图像分割中全卷积网络的发展,显著目标检测得到了发展。

目前现状:

大多数的SOD网络有一个共性,就是注重利用现有的主干提取深层特征,比如Alexnet, VGG, ResNet, ResNeXt, DenseNet等。这些主干最终都是为图像分类任务而设计的,他们提取的特征代表语义,而不是局部细节和全局对比信息,但这对显著性检测至关重要。

他们需要在ImageNet数据集上进行预训练,如果目标数据与ImageNet具有不同的分布,则会比较低效。

当前SOD模型中有哪些问题呢?

  • 网络结构复杂,这是由于在现有主干网络上添加特征聚合模块,从这些模型中提取多层显著性特征;
  • 现有主干网络通常通过牺牲高分辨率的特征映射来实现更深层次的体系结构。

因此,后续问题是:我们能否在保持高分辨率特征地图的同时,以较低的内存和计算成本进行深入研究?

U2Net网络,解决了上述两个问题:

  • 第一,该网络是一个两层嵌套的U型结构,没有使用图像分类的预训练主干模型,可以从头训练;
  • 第二,新的体系结构允许网络更深入、获得高分辨率,而不会显著增加内存和计算成本。在底层,设计了一个新的RSU,能够在不降低特征映射分辨率的情况下提取级内多尺度特征;在顶层,有一个类似于U-Net的结构,每一stage由一个RSU块填充。

相关工作总结

主要的研究方向在于多层次与多尺度特征提取上。

  • 多层次深层特征集成方法:主要集中在开发更好的多层次特征聚合策略上。
  • 多尺度特征提取方法:旨在设计新的模块,从主干网络获取的特征中提取局部和全局信息。

3. Residual U-blocks

对于显著目标检测和其他分割任务来说,局部和全局上下文信息都非常重要。在现代CNN设计中,如VGGResNetDenseNet等,1×13×3的小型卷积滤波器是最常用的特征提取元件。它们是受欢迎的,因为它们需要较少的存储空间并且计算效率高。
在这里插入图片描述
上图(a)-(c) 显示了具有最小感受野的现有卷积块。由于1x13x3滤波器的感受野太小而无法捕捉全局信息,因此底层的输出特征图只包含局部特征。为了在高分辨率浅层特征图上获取更多的全局信息,最直接的想法是扩大感受野。图(d)中显示了一个inception like block(不知如何翻译),试图通过使用空洞卷积来扩大感受野以提取局部特征和非局部特征。然而,在原始分辨率的输入特征图上进行多次扩展卷积(尤其是初始阶段)需要大量的计算和内存资源。

RSU构成

U-Net的启发,提出了一种新的RSU来捕捉阶段内多尺度特征。上图中(e)显示了RSU-L(Cin, M, Cout)结构, 其中L是编码器层数 C i n C_{in} Cin C o u t C_{out} Cout表示输入和输出通道,M表示RSU内部层中的通道数。因此,我们的RSU主要由三个部分组成:

  1. 输入卷积层:它将输入特征图x (HxWxCin)转成一个具有Cout通道数的中间图F1(x),这是提取局部特征的普通卷积层。
  2. 以中间特征图F1(x)为输入,学习提取和编码多尺度上下文信息 U ( F 1 ( x ) ) U(F_1(x)) U(F1(x))U表示如图2(e)所示的U-NetL越大,RSU越深,池化操作越多,更大的感受野以及更丰富的局部和全局特征。配置此参数可以从具有任意空间分辨率的输入特征图中提取多尺度特征。从梯度降采样特征图中提取多尺度特征,并通过逐步上采样、合并和卷积等方法将其编码到高分辨率特征图中。这一过程减少了大尺度直接上采样造成的细节损失。
  3. 通过求和: F 1 ( x ) + U ( F 1 ( x ) ) F_1(x)+U(F_1(x)) F1(x)+U(F1(x)) 融合局部特征和多尺度特征。

残差块与RSU的对比
在这里插入图片描述
上图表示残差块与RSU对比,主要设计区别在于,RSUU-Net代替了普通的单流卷积,并用一个权重层构成的局部特征代替了原始特征: H R S U ( x ) = U ( F 1 ( x ) ) + F 1 ( x ) H_{RSU}(x)=U(F_1(x)) + F_1(x) HRSU(x)=U(F1(x))+F1(x),其中U代表图2(e)所示的多层U型结构。这种设计变化使网络能够直接从每个残差块的多个尺度中提取特征。值得注意的是,由于U结构的计算开销很小,因为大多数操作都应用于下采样的特征映射。图4中展示了RSU和图(a)-(d)中其他特征提取模型之间的计算成本比较。
图4
图4. 计算成本比较:根据将尺寸为320x320x3的输入特征图传输到320x320x6输出特征图的基础上计算。

4. U2-Net架构

问题:如何组合多个RSU以达到好的效果呢?

级联模式

通常多个类似U-Net按顺序堆叠,以建立级联模型,并可归纳为(Uxn-Net)n是重复U-Net模块的数目,带来的问题是计算和内存开销被n放大了。如DocUNet, CU-Net网络等,如下图所示,为DocUNet网络的构成:
在这里插入图片描述
图: DocUNet网络结构

U型嵌套模式

作者提出一种不同的U型结构叠加模型。我们的指数表示法是指嵌套的U型结构,而不是级联叠加。理论上,可以将指数n设为任意正整数,实现单级或多级嵌套U型结构。但是,嵌套层太多的体系结构过于复杂,无法在实际中实现和应用。

我们将n设为2来构建U2-Net,是一个两层嵌套的U型结构,如图5所示。它的顶层是一个由11 stages(图5中的立方体)组成的大U型结构,每一stage由一个配置良好的RSU填充。因此,嵌套的U结构可以更有效的提取stage内的多尺度特征和聚集阶段的多层次特征。
在这里插入图片描述

如图5所示,U2-Net网络由三部分构成:

  • 六级编码器
  • 五级解码器
  • 与解码器和最后一级编码器相连的显著图融合模型

(Ⅰ)编码器En_1, En_2En_3En_4阶段中,分别使用RSU-7RSU-6RSU-5RSU-4RSU结构。前面的数字如7, 6, 5, 4RSU的高度LL通常根据输入特征图的空间分辨率进行配置。在En_5En_6中,特征图的分辨率相对较低,进一步对这些特征图进行下采样会导致有用上下文的丢失。因此,RSU-5RSU-6阶段,使用RSU-4F,其中F表示RSU是一个扩展的版本,其中我们用扩展卷积来代替合并和上采样操作。这意味着RSU-4F的所有中间特征图都与其输入的特征图具有相同的分辨率。

(Ⅱ)解码阶段具有与En_6中对称编码阶段相似的结构。在De_5中,还使用了扩展板的RSU-4F,这与在编码阶段En_5En_6中使用的类似。每个解码器阶段将来自前一级的上采样特征映射和来自其对称编码器阶段的特征映射的级联作为输入,见图5。

(Ⅲ)最后一部分是显著图融合模块,用于生成显著概率图。U2-Net网络首先通过3x3卷积和Sigmoid函数从En_6De_5De_4De_3De_2De_1生成6个输出显著概率图 S s i d e ( 6 ) S_{side}^{(6)} Sside(6) S s i d e ( 5 ) S_{side}^{(5)} Sside(5) S s i d e ( 4 ) S_{side}^{(4)} Sside(4) S s i d e ( 3 ) S_{side}^{(3)} Sside(3) S s i d e ( 2 ) S_{side}^{(2)} Sside(2) S s i d e ( 1 ) S_{side}^{(1)} Sside(1)。然后,将输出的显著图的逻辑图(卷积输出,Sigmoid函数之前)向上采样至与输入图像大小一致,并通过级联操作相融合,然后通过1x1卷积层和一个Sigmoid函数,以生成最终的显著性概率映射图 S f u s e S_{fuse} Sfuse

总结

U2-Net网络的设计允许具有丰富多尺度特性和相对较低的计算和内存成本低 深层架构。该结构只建立在RSU块上,没有使用任何特性分类的预训练主干网络,因此是灵活的,可适应不同的工作环境,性能损失很小。

本文中,使用不同的滤波器配置提供两种情况下的U2-Net:普通版本的U2-Net(176.3MB)和较小版本的U2-Net(4.7MB)。

5. 损失函数Loss

在这里插入图片描述
其中 l s i d e ( m ) ( M = 6 , 表 示 图 5 中 S u p 1 , . . . S u p 6 ) l_{side}^{(m)} (M=6, 表示图5中Sup1, ...Sup6) lside(m)(M=6,5Sup1,...Sup6)是输出显著图 S s i d e ( m ) S_{side}^{(m)} Sside(m)的loss, l f u s e l_{fuse} lfuse是最后的融合输出显著图的loss。 ω s i d e ( m ) ω_{side}^{(m)} ωside(m) ω f u s e ω_{fuse} ωfuse是每个loss项的权重。对于每一项,我们使用标准二进制交叉熵来计算损失:
在这里插入图片描述
其中, ( r , c ) (r,c) (r,c)为像素坐标, ( H , W ) (H,W) (H,W)为图像大小:高度和宽度。 P G ( r , c ) P_{G(r,c)} PG(r,c) P S ( r , c ) P_{S(r,c)} PS(r,c)分别表示GT像素值和预测的显著概率图。
训练过程试图最小化整个损失。测试过程中,我们选择最后融合结果 l f u s e l_{fuse} lfuse作为最终的显著性图。

6. 数据集

  1. 训练数据集
    使用的是DUTS-TR,它是DUTS数据集的一部分。DUTS-TR一共包含10553张图像。目前,它是用于显著性目标检测的最大和最常用的训练数据集。通过水平翻转来扩充数据集,共获得21106个训练图像。
  2. 评估数据集
    使用6个常用的基准数据集来评估,包括:DUT-OMRONDUTS-TEHKU-ISECSSDPASCAL-SSOD
    • DUT-OMRON包括5168幅图像,其中大多数包含1到2个结构复杂的前景对象。
    • DUTS数据集由两部分组成:DUTS-TRDUTS-TEDUTS-TE包含5019幅图像。
    • HKU-IS包含4447幅图像,有多个前景图像。
    • ECSSD包含1000个结构复杂的图像,其中许多包含大型前景对象。
    • PASCAL-S包含850幅前景复杂、背景杂乱的图像。
    • SOD只包含300幅图像。但是很具有挑战新,因为它最初是为图像分割而设计的,且许多图像的对比度低,或者包含与图像边界重叠的复杂前景对象。

7. 评估准则

为了综合评估这些概率图的质量,我们采用了六种度量方法,包括:(1)Precision-Recall(PR)曲线;(2)最大F-measure( m a x F β maxF_β maxFβ),(3)平均绝对误差(MAE),(4)加权F-measure( F β w F_β^w Fβw),(5)structure measure( S m S_m Sm),(6)relaxed F-measure of boundary( r e l a x F β b relaxF_β^b relaxFβb)。

看一组结果:
在这里插入图片描述

二、代码分析

代码github: NathanUA/U^2-Net

该网络是为了显著性目标检测而设计的,就在前几天(2020/11/21)作者更新了该网络的另一个应用,即人脸肖像画生成。

1、显著性检测

所需库文件:

numpy 1.15.2
scikit-image 0.14.0
python-opencv PIL 5.2.0
PyTorch 0.4.0
torchvision 0.2.1
glob

安装使用

  1. 下载代码
    git clone https://github.com/NathanUA/U-2-Net.git

  2. 下载预训练模型 u2net.pth (176.3 MB)u2netp.pth (4.7 MB) 并将其放入目录'./saved_models/u2net/''./saved_models/u2netp/'

  3. 进入目录U-2-Net, 通过命令分别运行训练过程和测试过程: python u2net_train.pypython u2net_test.py. 两个文件中的'model_name'可以改为'u2net''u2netp'对于不同的模型。

模型地址:
u2net.pth
u2netp.th
模型下载不下来的,可以去U2Net 网络预训练模型u2net.pth下载。

这份代码非常简洁,看起来也十分顺手!

测试显著检测效果,需指定模型名称,测试数据路径,然后,执行命令:
python u2net_test.py

数据加载,模型加载,模型推理,结果的保存等这些步骤在u2net_test.py脚本中均已经提供了,没有太多要说的。
结果如下:
在这里插入图片描述
可以看出,得到的结果非常好,很多细小的毛发处也处理的比较细腻。

2、人脸肖像画生成

最近,研究者又将其应用于人脸肖像画的生成中,并基于 APDrawingGAN 数据集为此类任务训练了新的模型。不管是儿童肖像还是成年男性、成年女性,都能获得相当细致。
先看一下效果:
在这里插入图片描述
安装使用

  1. 克隆仓库到本地
    git clone https://github.com/NathanUA/U-2-Net.git
  2. 下载u2net_portrait.pth模型,并经其放置在'./saved_models/u2net_portrait'
  3. 准备自己的数据并将他们放在'./test_data/test_portrait_images/your_portrait_im'。(当然可以自己指定存放路径,代码中能找到对应路径即可)为了获取足够的肖像细节,输入图像的人头区域应该接近甚至大于512x512.头部背景应该相对干净些。
  4. 通过命令'python u2net_portrait_demo.py'运行预测,结果在'./test_data/test_portrait_images/your_portrait_results'中。

注:预训练模型无法下载的,可以去u2net_portrait.pth处下载。

'python u2net_portrait_demo.py''python u2net_portrait_test.py'的不同之处在于:
u2net_portrait_demo.py中增加了一个简单的人脸检测步骤在肖像生成之前。因为APDrawingGAN测试集被归一化并裁剪为512x512大小只包含头部,而我们自己的数据集可能包含不同分辨率和内容。

因此,python u2net_portrait_demo.py代码将会在给定的图像中检测并裁剪大人脸区域,pad并resize到512x512以喂给网络。
下面的图像展示了如何拿自己的照片产生高质量的肖像图:
在这里插入图片描述

  • 70
    点赞
  • 307
    收藏
    觉得还不错? 一键收藏
  • 18
    评论
下面是一个简单的PyQt5界面实现U2Net图像分割的例子,使用PyTorch实现。 ``` import sys import os import numpy as np from PIL import Image from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QPushButton, QFileDialog from PyQt5.QtGui import QPixmap import torch import torchvision.transforms as transforms from model.u2net import U2NET class MainWindow(QMainWindow): def __init__(self): super().__init__() # 创建UI界面 self.initUI() # 加载模型 self.model = U2NET() self.model.load_state_dict(torch.load("u2net.pth", map_location=torch.device('cpu'))) self.model.eval() def initUI(self): # 设置窗口标题和大小 self.setWindowTitle("U2Net Image Segmentation") self.setGeometry(100, 100, 800, 600) # 创建标签和按钮 self.label = QLabel(self) self.label.setGeometry(25, 50, 750, 450) self.label.setStyleSheet("border: 1px solid black;") self.button = QPushButton("Select Image", self) self.button.setGeometry(25, 525, 150, 50) self.button.clicked.connect(self.selectImage) self.button2 = QPushButton("Segment Image", self) self.button2.setGeometry(200, 525, 150, 50) self.button2.clicked.connect(self.segmentImage) def selectImage(self): # 打开文件对话框,选择要处理的图像 options = QFileDialog.Options() options |= QFileDialog.DontUseNativeDialog fileName, _ = QFileDialog.getOpenFileName(self,"QFileDialog.getOpenFileName()", "","All Files (*);;Images (*.png *.jpg *.jpeg)", options=options) if fileName: # 加载图像并显示在标签上 pixmap = QPixmap(fileName) pixmap = pixmap.scaled(750, 450) self.label.setPixmap(pixmap) # 将图像转换为PyTorch tensor格式 self.input_image = Image.open(fileName).convert("RGB") self.transform = transforms.Compose([transforms.Resize((320, 320)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) self.input_tensor = self.transform(self.input_image).unsqueeze(0) def segmentImage(self): # 对选择的图像进行分割 with torch.no_grad(): output_tensor = self.model(self.input_tensor) # 将输出转换为PIL Image格式 output_tensor = output_tensor.squeeze().numpy() output_tensor = np.where(output_tensor > 0.5, 1.0, 0.0) output_image = Image.fromarray((output_tensor * 255).astype(np.uint8)).convert("L") # 显示分割结果 output_pixmap = QPixmap.fromImage(ImageQt(output_image)) output_pixmap = output_pixmap.scaled(750, 450) self.label.setPixmap(output_pixmap) if __name__ == "__main__": # 创建应用程序和主窗口 app = QApplication(sys.argv) mainWindow = MainWindow() mainWindow.show() sys.exit(app.exec_()) ``` 在上面的代码中,我们首先创建了一个`MainWindow`类,它继承自`QMainWindow`类,并重写了`initUI`方法来创建UI界面。在`initUI`方法中,我们创建了一个标签和两个按钮,其中一个用于选择图像,另一个用于对图像进行分割。 在选择图像按钮的回调函数`selectImage`中,我们使用`QFileDialog`打开一个文件对话框,让用户选择要处理的图像。然后,我们使用`PIL`库来加载图像,并将其转换为PyTorch tensor格式。在转换过程中,我们使用了`transforms`模块来对图像进行缩放、标准化等预处理操作。 在对图像进行分割的按钮回调函数`segmentImage`中,我们将输入张量传递给已加载的U2Net模型,并得到输出张量。然后,我们将输出张量转换为PIL Image格式,并将其显示在标签上。在转换过程中,我们使用了NumPy来将输出张量转换为二值图像,使用`PIL`库将其转换为灰度图像,并使用`QPixmap`将其转换为Qt图像格式。 最后,我们在`__main__`函数中创建了应用程序和主窗口,并调用`show`方法来显示窗口。
评论 18
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值