U2Net使用方法和实现多类别语义分割模型改造

作者的碎碎念:U2Net是用来实现SOD的语义分割,本篇论文会介绍算法内容、主要代码、使用方法,以及如何将二分类语义分割修改为多类别语义模型。如果只想知道怎么训练自己的数据集,或者如何修改网络,可以通过目录进行跳转。
欢迎点赞、评论或收藏❤️


(一)相关链接

  1. 论文名称
    《U2-Net: Going Deeper with Nested U-Structure for Salient Object Detection》
  2. github链接
    https://github.com/xuebinqin/U-2-Net
  3. paper
    https://arxiv.org/pdf/2005.09007.pdf

(二)算法内容

1. 摘要

  U²-Net是显著物体检测(salient object detection,简写SOD)的一个网络,并且现在已经是Python的抠图工具Rembg的基础算法

  • 什么是SOD?
      SOD是模拟人类视觉感知系统来定位场景中最吸引人的目标,例如人像
  • 算法优点总结
    (1)能获取到更多的上下文信息(RSU块,ReSidual U-blocks)
    (2)增加网络深度但没有增加计算量。并且可以从0开始训练,不用从分类预训练网络中再训练
  • 模型大小
      U2-Net (176.3 MB, 30 FPS on GTX 1080Ti GPU)
      U2-Net†(4.7 MB, 40 FPS)

2. 介绍

  • 现有的SOD网络存在什么问题?
    (1)现有的模式基本都是使用已有的backbone,例如AlexNet、VGG、ResNet。这些基础的网络都是为分类任务而设计的,提取的特征更多是语义特征,而不是定位特征和全局对比的信息。
    (2)耗用大量的资源
    (3)牺牲高分辨率的特征映射来实现更深层次的体系结构
  • U2Net的目标是网络更深、使用的资源和计算量更少、能够保持高分辨率的特征图。怎么做呢?
    (1)用两级的内嵌U型结构,不使用分类的backbone
    (2)新型的网络结构更深、能获取高分辨率图像、不增加内存和计算量

3. 网络架构

  • 卷积结构和RSU结构比对
    在这里插入图片描述

(1)( a ) Plain convolution blockPLN
     ( b ) Residual-like block RES
     ( c ) Dense-like block DSE
     ( d ) Inception-like block INC
     ( e ) Our residual U-blockRSU
(2)(a)到( c )是典型的卷积结构,用了1x1和3x3的卷积,感受野太小,只能用来获取local feature
(3)(d)用了空洞卷积增大了感受野,但是需要大的内存和计算资源
(4)RSU-L模块,(L代表层数),Cin:输入通道,Cout:输出通道,M:RSU内部通道

  • 开销比对
    在这里插入图片描述
    RSU的开销(overhead)不大,因为都是下采样,DSE和INC比较大
  • 残差结构比对
    在这里插入图片描述
    (1)残差块:H(x) = F2(F1(x))+x,H(x)是x的映射,F1和F2是卷积操作【对应两个weight layer】
    (2)RSU:HRSU (x) = U(F1(x))+F1(x),RSU和残差不同的地方,是将卷积替换成像Unet的U型结构U-block,原来的输入x替换成F1(x)【weight layer之后】
  • 网络架构
    在这里插入图片描述

  U-Net-like这种结构本来就有,只不过是级联起来,Uxn Net,而作者提出来的是 Un Net,用内嵌(nested)结构而不是级联结构
(1)结构特点:11个stage,每个stage都是RSU结构
   🔸 a six stages encoder
   🔸a five stages decoder
   🔸a saliency map fusion module attached with the decoder stages and the last encoder stage
(2)编码器:
   🔹En_1、En_2、En_3、En_4(即前四个)用到的RSU层数是 RSU-7、 RSU-6、 RSU-5、 RSU-4,层数越多,尺度信息越丰富
   🔹En-5和En-6用了RSU-4F,用了空洞卷积,保证了输入输出是相同的分辨率
(3)解码器:
   De-5也是用了RSU-4F,和En-5、En-6类似
(4)融合模块(saliency map fusion module):
   编码器和解码器的输出,经过3x3卷积和sigmoid,upsample,输出了6个概率热力图:S_side(6)、S_side(5)、S_side(4)、S_side(3)、S_side(2)、S_side(1) ,用1x1卷积进行融合,产生了S_fuse

4. loss函数

在这里插入图片描述
✅总Loss等于所有loss之和,包括S_side(6)、S_side(5)、S_side(4)、S_side(3)、S_side(2)、S_side(1),和融合的S_fuse
在这里插入图片描述
✅每一层的S_side(x)的loss,使用了二分类交叉熵损失函数

5. 作者实验结果

在这里插入图片描述
Red, Green, and Blue indicate the best, second best and third best performance
在这里插入图片描述

(三)如何训练自己的数据

1. 标注

用labelme标注图片,生成json文件
在这里插入图片描述

2. mask图像

将json文件转换为mask图片,背景黑色,物体白色,下面是转换代码:

import cv2
import json
import numpy as np
import os
import sys


def func(file:str) -> np.ndarray:
    with open(file, mode='r', encoding="utf-8") as f:
        configs = json.load(f)
    shapes = configs["shapes"]

    png = np.zeros((configs["imageHeight"], configs["imageWidth"], 3), np.uint8)

    for shape in shapes:
        cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (255,255,255))

    return png


if  __name__ == "__main__":

    if len(sys.argv) != 3:
        raise ValueError("json文件或目录 输出路径")

    if os.path.isdir(sys.argv[1]):
        for file in os.listdir(sys.argv[1]):
            cv2.imwrite(os.path.join(sys.argv[2], os.path.splitext(file)[0]+".png" ), func(os.path.join(sys.argv[1], file)))
    else:
        cv2.imwrite(os.path.join(sys.argv[2], os.path.splitext(os.path.basename(sys.argv[1]))[0]+".png"), func(sys.argv[1]))

在这里插入图片描述

转换的mask图像

3. 训练数据集格式

1️⃣在工程目录创建目录:train_data/DUTS/DUTS-TR/DUTS-TR/
2️⃣在第一步骤创建的目录上,创建目录im_aug,将原图放在这
3️⃣在第一步骤创建的目录上,创建目录gt_aug,将转换的mask图放在这

4. 配置文件修改

  打开u2net_train.py,一般可以设置这几项:
  model_name = ‘u2net’ # 用u2net或者u2netp模型进行训练
  epoch_num = 100000 # 训练轮次
  batch_size_train = 12 # batchsize
  save_frq = 2000 # 每2000个iter保存一个模型

5. 训练命令

python u2net_train.py

6. 测试命令

python u2net_test.py

(四)多类别语义分割

  作者提供的代码只实现了二分类的语义分割,U2Net是否可以用来做多类别的语义分割?答案是可以了,下面提供了将二分类语义分割转换为多类别语义分割的方法

2023.11.14新增完整代码在资源,0积分即可下载,有需要的自取
https://download.csdn.net/download/jin__9981/88533519

1. 实现思路

🔺项目背景:图片有两个类别,分别是螺丝钉和位移线
🔺类别:两个类别+背景,num_class = 3,如果有更多类别,则是n+1类,1是背景
🔺mask图片:二分类时,填充的是0和255;多分类,不同类别可以填充为0(背景)、1(螺丝钉)、2(位移线),所以最多只能分出0~255个类别。查看3个类别的mask,因为像素值只有0、1、2,肉眼看基本是一张黑色图像
🔺模型输出:三个类别,输出三个通道,如[3, 320, 320],每一个通道代表一个类别

2. 修改方法

(1)获取多类别训练mask脚本

import cv2
import json
import numpy as np
import os
import sys


def func(file):
    with open(file, mode='r', encoding="utf-8") as f:
        configs = json.load(f)
    shapes = configs["shapes"]

    png = np.zeros((configs["imageHeight"], configs["imageWidth"], 3), np.uint8)

    for shape in shapes:
        label = shape['label']
        if label == 'lm':
            cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (1,1,1))
        else:
            cv2.fillPoly(png, [np.array(shape["points"], np.int32)], (2,2,2))

    return png


if  __name__ == "__main__":
    json_dir = "./train_data/labels_json"
    
    save_dir = './train_data/masks'


    for file in os.listdir(json_dir):
        print(file)
        png = func(os.path.join(json_dir, file))
        print(png.shape)
        save_path = save_dir+'/'+os.path.splitext(file)[0]+".png"
        cv2.imwrite(save_path, png)
        print(save_path)

(2)data_loader.py
   class ToTensor(object)和class ToTensorLab(object)这两个类中,有对label进行归一化操作,去除该操作,因为计算loss的时候,多类别换成交叉熵损失函数,它本身包含了softmax操作
在这里插入图片描述
(3)model/u2net.py
   修改模型输出,作者在class U2NETP(nn.Module)和class U2NET(nn.Module)这两个类用了sigmoid函数,需要修改为直接输出,原因同上

# return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
return d0, d1, d2, d3, d4, d5, d6

(4)u2net_train.py
   修改损失函数和模型输出通道,将损失函数由原来的BCELoss,修改为CrossEntropyLoss,并设置模型的输出通道和类别一致

# bce_loss = nn.BCELoss(size_average=True)  # 注释
ce_loss = nn.CrossEntropyLoss()  # 添加
# def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):  # 注释
#     loss0 = bce_loss(d0, labels_v)
#     loss1 = bce_loss(d1, labels_v)
#     loss2 = bce_loss(d2, labels_v)
#     loss3 = bce_loss(d3, labels_v)
#     loss4 = bce_loss(d4, labels_v)
#     loss5 = bce_loss(d5, labels_v)
#     loss6 = bce_loss(d6, labels_v)

#     loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
#     print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
#     loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(),
#     loss6.data.item()))

#     return loss0, loss

def muti_ce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):  # 添加
    loss0 = ce_loss(d0, labels_v)
    loss1 = ce_loss(d1, labels_v)
    loss2 = ce_loss(d2, labels_v)
    loss3 = ce_loss(d3, labels_v)
    loss4 = ce_loss(d4, labels_v)
    loss5 = ce_loss(d5, labels_v)
    loss6 = ce_loss(d6, labels_v)

    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
    loss0.data.item(), loss1.data.item(), loss2.data.item(), loss3.data.item(), loss4.data.item(), loss5.data.item(),
    loss6.data.item()))

    return loss0, loss
# ------- 3. define model --------
# define the net
n_class = 3
if (model_name == 'u2net'):
    net = U2NET(3, n_class)
elif (model_name == 'u2netp'):
    net = U2NETP(3, n_class)

4. 测试

   该例子中,存在三个类别,分别是背景、螺丝钉、位移线,对应模型三个通道的输出,但模型输出为概率值,如何获取到真实的类别,以及将类别用不同颜色表示出来?可以用下面这个脚本实现模型推理和输出结果图

import os
import cv2
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset

from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB

# normalize the predicted SOD probability map
def normPRED(d):
    ma = torch.max(d)
    mi = torch.min(d)

    dn = (d-mi)/(ma-mi)

    return dn

def save_output(image_name,pred,d_dir):

    predict = pred
    predict = predict.squeeze()
    predict_np = predict.cpu().data.numpy()

    im = Image.fromarray(predict_np*255).convert('RGB')
    img_name = image_name.split(os.sep)[-1]
    image = io.imread(image_name)
    imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)

    pb_np = np.array(imo)

    aaa = img_name.split(".")
    bbb = aaa[0:-1]
    imidx = bbb[0]
    for i in range(1,len(bbb)):
        imidx = imidx + "." + bbb[i]

    imo.save(d_dir+imidx+'.png')

def main():

    # --------- 1. get image path and name ---------
    model_name='u2net'#u2netp

    num_class = 3

    image_dir = os.path.join(os.getcwd(), 'test_data', 'ls_test_images')
    prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results_ls' + os.sep)
    model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, 'u2net_bce_itr_1000_train_1.046126_tar_0.124982.pth')

    img_name_list = glob.glob(image_dir + os.sep + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    #1. dataloader
    test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
                                        lbl_name_list = [],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if(model_name=='u2net'):
        print("...load U2NET---173.6 MB")
        net = U2NET(3,num_class)
    elif(model_name=='u2netp'):
        print("...load U2NEP---4.7 MB")
        net = U2NETP(3,num_class)

    if torch.cuda.is_available():
        net.load_state_dict(torch.load(model_dir))
        net.cuda()
    else:
        net.load_state_dict(torch.load(model_dir, map_location='cpu'))
    net.eval()

    # --------- 4. inference for each image ---------
    for i_test, data_test in enumerate(test_salobj_dataloader):

        print("inferencing:",img_name_list[i_test].split(os.sep)[-1])

        inputs_test = data_test['image']

        image = cv2.imread(img_name_list[i_test])
        image_name = os.path.basename(img_name_list[i_test])

        inputs_test = inputs_test.type(torch.FloatTensor)

        if torch.cuda.is_available():
            inputs_test = Variable(inputs_test.cuda())
        else:
            inputs_test = Variable(inputs_test)

        d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
        d1 = d1.squeeze(dim=0)    # torch.Size([1, 3, 320, 320]) -> torch.Size([3, 320, 320])
        
        d1 = F.softmax(d1, dim=0)   # [3, 320, 320] 
        # print(d1[0, :, :])

        predict_np = torch.argmax(d1, dim=0, keepdim=True)
        # print(predict_np.shape)  # [1, 320, 320],3个类别,对应3个通道,获取概率值最高的下标

        predict_np = predict_np.cpu().detach().numpy().squeeze()   # 转到cpu设备

        predict_np = cv2.resize(predict_np, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)  # resize和原图一样的大小
        
        r = predict_np.copy()
        b = predict_np.copy()
        g = predict_np.copy()

        cls = dict([(1, (0, 0, 255)),
                    (2, (255, 0, 255)),
                    (3, (0, 255, 0)),
                    (4, (255, 0, 0)),
                    (5, (255, 255, 0))])
        for c in cls:
            r[r == c] = cls[c][0]
            g[g == c] = cls[c][1]
            b[b == c] = cls[c][2]

        rgb = np.zeros((image.shape[0], image.shape[1], 3))
        # print('类别', np.unique(predict_np))
        rgb[:, :, 0] = r
        rgb[:, :, 1] = g
        rgb[:, :, 2] = b

        im = Image.fromarray(rgb.astype(np.uint8))
        im.save('./test_data/my_results_2/' + str(image_name)[:-4] + '.png')

        del d1,d2,d3,d4,d5,d6,d7

if __name__ == "__main__":
    main()

5. 训练测试效果

   经过少量数据的训练测试,证明U2Net可以用来做多类别语义分割
输入图片

输入测试图片

在这里插入图片描述

模型测试效果

撒花完结🌟🌟🌟

  • 11
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
下面是一个简单的PyQt5界面实现U2Net图像分割的例子,使用PyTorch实现。 ``` import sys import os import numpy as np from PIL import Image from PyQt5.QtWidgets import QApplication, QMainWindow, QLabel, QPushButton, QFileDialog from PyQt5.QtGui import QPixmap import torch import torchvision.transforms as transforms from model.u2net import U2NET class MainWindow(QMainWindow): def __init__(self): super().__init__() # 创建UI界面 self.initUI() # 加载模型 self.model = U2NET() self.model.load_state_dict(torch.load("u2net.pth", map_location=torch.device('cpu'))) self.model.eval() def initUI(self): # 设置窗口标题和大小 self.setWindowTitle("U2Net Image Segmentation") self.setGeometry(100, 100, 800, 600) # 创建标签和按钮 self.label = QLabel(self) self.label.setGeometry(25, 50, 750, 450) self.label.setStyleSheet("border: 1px solid black;") self.button = QPushButton("Select Image", self) self.button.setGeometry(25, 525, 150, 50) self.button.clicked.connect(self.selectImage) self.button2 = QPushButton("Segment Image", self) self.button2.setGeometry(200, 525, 150, 50) self.button2.clicked.connect(self.segmentImage) def selectImage(self): # 打开文件对话框,选择要处理的图像 options = QFileDialog.Options() options |= QFileDialog.DontUseNativeDialog fileName, _ = QFileDialog.getOpenFileName(self,"QFileDialog.getOpenFileName()", "","All Files (*);;Images (*.png *.jpg *.jpeg)", options=options) if fileName: # 加载图像并显示在标签上 pixmap = QPixmap(fileName) pixmap = pixmap.scaled(750, 450) self.label.setPixmap(pixmap) # 将图像转换为PyTorch tensor格式 self.input_image = Image.open(fileName).convert("RGB") self.transform = transforms.Compose([transforms.Resize((320, 320)), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) self.input_tensor = self.transform(self.input_image).unsqueeze(0) def segmentImage(self): # 对选择的图像进行分割 with torch.no_grad(): output_tensor = self.model(self.input_tensor) # 将输出转换为PIL Image格式 output_tensor = output_tensor.squeeze().numpy() output_tensor = np.where(output_tensor > 0.5, 1.0, 0.0) output_image = Image.fromarray((output_tensor * 255).astype(np.uint8)).convert("L") # 显示分割结果 output_pixmap = QPixmap.fromImage(ImageQt(output_image)) output_pixmap = output_pixmap.scaled(750, 450) self.label.setPixmap(output_pixmap) if __name__ == "__main__": # 创建应用程序和主窗口 app = QApplication(sys.argv) mainWindow = MainWindow() mainWindow.show() sys.exit(app.exec_()) ``` 在上面的代码中,我们首先创建了一个`MainWindow`类,它继承自`QMainWindow`类,并重写了`initUI`方法来创建UI界面。在`initUI`方法中,我们创建了一个标签和两个按钮,其中一个用于选择图像,另一个用于对图像进行分割。 在选择图像按钮的回调函数`selectImage`中,我们使用`QFileDialog`打开一个文件对话框,让用户选择要处理的图像。然后,我们使用`PIL`库来加载图像,并将其转换为PyTorch tensor格式。在转换过程中,我们使用了`transforms`模块来对图像进行缩放、标准化等预处理操作。 在对图像进行分割的按钮回调函数`segmentImage`中,我们将输入张量传递给已加载的U2Net模型,并得到输出张量。然后,我们将输出张量转换为PIL Image格式,并将其显示在标签上。在转换过程中,我们使用了NumPy来将输出张量转换为二值图像,使用`PIL`库将其转换为灰度图像,并使用`QPixmap`将其转换为Qt图像格式。 最后,我们在`__main__`函数中创建了应用程序和主窗口,并调用`show`方法来显示窗口。
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值