03 U2net

目录

一、 理论知识

1. 网络架构

2. RSU

3. 显著特征融合模块

4. 损失计算​编辑

二、 代码实现

0. DUTS数据集

1. transforms.ToTensor()

2. train

3. 验证

4. 预测


原文链接https://blog.csdn.net/qq_37541097/article/details/126255483

一、 理论知识

U2Net是针对Salient Object Detetion(SOD)即显著性目标检测任务。

显著性目标检测任务与语义分割任务相似,只不过显著性目标检测任务是二分类任务,它的任务是将图片中最吸引人的目标或区域分割出来,故只有前景和背景两类。

1. 网络架构

在大的UNet中嵌入了一堆小UNet

2. RSU

En_1En_2En_3En_4De_1De_2De_3De_4采用的是同一种Block,只是深度不同。

Block就是论文中提出的ReSidual U-block简称RSU(也就是小Unet)

  • En_1De_1采用的是RSU-7En_2De_2采用的是RSU-6En_3De_3采用的是RSU-5En_4De_4采用的是RSU-4
  • En_5En_6De_5三个模块采用的是RSU-4FRSU-4FRSU-4两者结构并不相同
  • 带参数d的卷积层全部是膨胀卷积,d为膨胀系数

深度为7的Block ---------- RSU-7

RSU-4F将采样层全部替换成了膨胀卷积:

3. 显著特征融合模块

saliency map fusion module即显著特征融合模块 ---------- 最终输出模块

对照1中网络结构图

  • 收集De_1、De_2、De_3、De_4、De_5以及En_6的输出;
  • 分别通过一个3x3的卷积层得到channel为1的特征图;
  • 双线性插值缩放到输入图片大小得到Sup1、Sup2、Sup3、Sup4、Sup5和Sup6;
  • 6个特征图进行Concat拼接
  • 通过一个1x1的卷积层以及Sigmiod激活函数得到最终的预测概率图

4. 损失计算

 评价指标:

二、 代码实现

显著性目标检测只区分背景和前景,本项目使用的数据集是DUTS数据集

target:黑白图二维(H,W)------- 背景:0 ;   前景:255 ;  边缘地方:小的数值

U2Net没有torch官方实现;需要使用时修改项目中的代码

0. DUTS数据集

目录结构如下:

├── DUTS-TR
│      ├── DUTS-TR-Image: 该文件夹存放所有训练集的图片
│      └── DUTS-TR-Mask: 该文件夹存放对应训练图片的GT标签(Mask蒙板形式)
│
└── DUTS-TE
       ├── DUTS-TE-Image: 该文件夹存放所有测试(验证)集的图片
       └── DUTS-TE-Mask: 该文件夹存放对应测试(验证)图片的GT标签(Mask蒙板形式)

mask只区分前景背景:(如下图)

 

该数据集没有用到调色板!!!!!

自定义DUTS数据集,详见my_dataset.py文件

1. transforms.ToTensor()

针对image:

  • 转变形状,将H,W,C转变成C,H,W
  • 转变数据格式,返回的tensor格式
  • 归一化,可以将数据的范围变成[0,1]----所有像素值÷255

针对target:

分割类任务,target是二维(H,W)

  • 经过.ToTensor()后,自动增加维度,变为(1,H,W)
  • 转变数据格式,返回的tensor格式
  • 归一化,可以将数据的范围变成[0,1]----所有像素值÷255

2. train

由于torch官方没有实现该模型,加载作者训练好的.pth文件

# ********************************************* 加载.pth参数
# 建立model
model = u2net_full()
# 查看model的初始化参数值
old_dic = model.state_dict()
# 加载.pth参数文件
weight_path = './u2net_full.pth'
pre_dic = torch.load(weight_path, map_location=device)
# 将.pth参数加载到model中,只会将字典名称完全一样以及shape相同的加载进去
# 返回model中未加载成功的参数   以及   .pth中多余的参数(与model不匹配)------strict=False
missing_keys, unexpected_keys = model.load_state_dict(pre_dic, strict=False)
print(missing_keys, unexpected_keys)
# 再次查看model的参数值,检查是否已经更换成功
new_dic = model.state_dict()
model.to(device)

损失计算:

# inputs就是model预测的结果-----多个特征图查看原理部分损失计算
def criterion(inputs, target):
    losses = [F.binary_cross_entropy_with_logits(inputs[i], target) for i in range(len(inputs))]
    total_loss = sum(losses)  # 求和
    return total_loss

3. 验证

需要建立两个评价指标-------  MeanAbsoluteError  和  F1Score

mae_metric, f1_metric = evaluate(model, val_data_loader, device=device)
...
...
def evaluate(,,,,):
    model.eval()
    mae_metric = utils.MeanAbsoluteError()
    f1_metric = utils.F1Score()
    ...
    with torch.no_grad():
        for images, targets in data_loader:
            images, targets = images.to(device), targets.to(device)
            output = model(images)
            mae_metric.update(output, targets)
            f1_metric.update(output, targets)
    return mae_metric, f1_metric
class MeanAbsoluteError(object):
    def __init__(self):
        self.mae_list = []

    def update(self, pred: torch.Tensor, gt: torch.Tensor):
        batch_size, c, h, w = gt.shape
        assert batch_size == 1, f"validation mode batch_size must be 1, but got batch_size: {batch_size}."
        resize_pred = F.interpolate(pred, (h, w), mode="bilinear", align_corners=False)
        error_pixels = torch.sum(torch.abs(resize_pred - gt), dim=(1, 2, 3)) / (h * w)
        self.mae_list.extend(error_pixels.tolist())

    def compute(self):
        mae = sum(self.mae_list) / len(self.mae_list)
        return mae

    def gather_from_all_processes(self):
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        gather_mae_list = []
        for i in all_gather(self.mae_list):
            gather_mae_list.extend(i)
        self.mae_list = gather_mae_list

    def __str__(self):
        mae = self.compute()
        return f'MAE: {mae:.3f}'


class F1Score(object):
    """
    refer: https://github.com/xuebinqin/DIS/blob/main/IS-Net/basics.py
    """

    def __init__(self, threshold: float = 0.5):
        self.precision_cum = None
        self.recall_cum = None
        self.num_cum = None
        self.threshold = threshold

    def update(self, pred: torch.Tensor, gt: torch.Tensor):
        batch_size, c, h, w = gt.shape
        assert batch_size == 1, f"validation mode batch_size must be 1, but got batch_size: {batch_size}."
        resize_pred = F.interpolate(pred, (h, w), mode="bilinear", align_corners=False)
        gt_num = torch.sum(torch.gt(gt, self.threshold).float())

        pp = resize_pred[torch.gt(gt, self.threshold)]  # 对应预测map中GT为前景的区域
        nn = resize_pred[torch.le(gt, self.threshold)]  # 对应预测map中GT为背景的区域

        pp_hist = torch.histc(pp, bins=255, min=0.0, max=1.0)
        nn_hist = torch.histc(nn, bins=255, min=0.0, max=1.0)

        # Sort according to the prediction probability from large to small
        pp_hist_flip = torch.flipud(pp_hist)
        nn_hist_flip = torch.flipud(nn_hist)

        pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
        nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)

        precision = pp_hist_flip_cum / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)
        recall = pp_hist_flip_cum / (gt_num + 1e-4)

        if self.precision_cum is None:
            self.precision_cum = torch.full_like(precision, fill_value=0.)

        if self.recall_cum is None:
            self.recall_cum = torch.full_like(recall, fill_value=0.)

        if self.num_cum is None:
            self.num_cum = torch.zeros([1], dtype=gt.dtype, device=gt.device)

        self.precision_cum += precision
        self.recall_cum += recall
        self.num_cum += batch_size

    def compute(self):
        pre_mean = self.precision_cum / self.num_cum
        rec_mean = self.recall_cum / self.num_cum
        f1_mean = (1 + 0.3) * pre_mean * rec_mean / (0.3 * pre_mean + rec_mean + 1e-8)
        max_f1 = torch.amax(f1_mean).item()
        return max_f1

    def reduce_from_all_processes(self):
        if not torch.distributed.is_available():
            return
        if not torch.distributed.is_initialized():
            return
        torch.distributed.barrier()
        torch.distributed.all_reduce(self.precision_cum)
        torch.distributed.all_reduce(self.recall_cum)
        torch.distributed.all_reduce(self.num_cum)

    def __str__(self):
        max_f1 = self.compute()
        return f'maxF1: {max_f1:.3f}'

4. 预测

...
threshold = 0.5
origin_img = cv2.cvtColor(cv2.imread(img_path, flags=cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
img = data_transform(origin_img)
img = torch.unsqueeze(img, 0).to(device)
...



pred = model(img)   
pred = torch.squeeze(pred).to("cpu").numpy()  # [1, 1, H, W] -> [H, W]
# 生成mask结果
pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
pred_mask = np.where(pred > threshold, 1, 0)     
origin_img = np.array(origin_img, dtype=np.uint8)
seg_img = origin_img * pred_mask[..., None]    
# 应该就是将原图中背景像素点变为0,前景像素点不变,  没去仔细查此行功能,待解决
# 保存
cv2.imwrite("pred_result.png", cv2.cvtColor(seg_img.astype(np.uint8), cv2.COLOR_RGB2BGR))

预测一张图片时,生成图如下:

训练以及验证时,数据集中的标签: 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值