目录
原文链接https://blog.csdn.net/qq_37541097/article/details/126255483
一、 理论知识
U2Net是针对Salient Object Detetion(SOD)即显著性目标检测任务。
显著性目标检测任务与语义分割任务相似,只不过显著性目标检测任务是二分类任务,它的任务是将图片中最吸引人的目标或区域分割出来,故只有前景和背景两类。
1. 网络架构
在大的UNet中嵌入了一堆小UNet
2. RSU
En_1
、En_2
、En_3
、En_4
、De_1
、De_2
、De_3
、De_4
采用的是同一种Block,
只是深度不同。
Block
就是论文中提出的ReSidual U-block
简称RSU(也就是小Unet)
En_1
和De_1
采用的是RSU-7
,En_2
和De_2
采用的是RSU-6
,En_3
和De_3
采用的是RSU-5
,En_4
和De_4
采用的是RSU-4
En_5
、En_6
和De_5
三个模块采用的是RSU-4F
,RSU-4F
和RSU-4
两者结构并不相同。- 带参数
d
的卷积层全部是膨胀卷积,d
为膨胀系数
深度为7的Block ---------- RSU-7
RSU-4F
将采样层全部替换成了膨胀卷积:
3. 显著特征融合模块
saliency map fusion module
即显著特征融合模块 ---------- 最终输出模块
对照1中网络结构图看
- 收集De_1、De_2、De_3、De_4、De_5以及En_6的输出;
- 分别通过一个3x3的卷积层得到channel为1的特征图;
- 双线性插值缩放到输入图片大小得到Sup1、Sup2、Sup3、Sup4、Sup5和Sup6;
- 6个特征图进行Concat拼接;
- 通过一个1x1的卷积层以及Sigmiod激活函数得到最终的预测概率图
4. 损失计算
评价指标:
二、 代码实现
显著性目标检测只区分背景和前景,本项目使用的数据集是DUTS数据集
target:黑白图二维(H,W)------- 背景:0 ; 前景:255 ; 边缘地方:小的数值
U2Net没有torch官方实现;需要使用时修改项目中的代码
0. DUTS数据集
目录结构如下:
├── DUTS-TR
│ ├── DUTS-TR-Image: 该文件夹存放所有训练集的图片
│ └── DUTS-TR-Mask: 该文件夹存放对应训练图片的GT标签(Mask蒙板形式)
│
└── DUTS-TE
├── DUTS-TE-Image: 该文件夹存放所有测试(验证)集的图片
└── DUTS-TE-Mask: 该文件夹存放对应测试(验证)图片的GT标签(Mask蒙板形式)
mask只区分前景背景:(如下图)
该数据集没有用到调色板!!!!!
自定义DUTS数据集,详见my_dataset.py文件
1. transforms.ToTensor()
针对image:
- 转变形状,将H,W,C转变成C,H,W
- 转变数据格式,返回的tensor格式
- 归一化,可以将数据的范围变成[0,1]----所有像素值÷255
针对target:
分割类任务,target是二维(H,W)
- 经过.ToTensor()后,自动增加维度,变为(1,H,W)
- 转变数据格式,返回的tensor格式
- 归一化,可以将数据的范围变成[0,1]----所有像素值÷255
2. train
由于torch官方没有实现该模型,加载作者训练好的.pth文件
# ********************************************* 加载.pth参数
# 建立model
model = u2net_full()
# 查看model的初始化参数值
old_dic = model.state_dict()
# 加载.pth参数文件
weight_path = './u2net_full.pth'
pre_dic = torch.load(weight_path, map_location=device)
# 将.pth参数加载到model中,只会将字典名称完全一样以及shape相同的加载进去
# 返回model中未加载成功的参数 以及 .pth中多余的参数(与model不匹配)------strict=False
missing_keys, unexpected_keys = model.load_state_dict(pre_dic, strict=False)
print(missing_keys, unexpected_keys)
# 再次查看model的参数值,检查是否已经更换成功
new_dic = model.state_dict()
model.to(device)
损失计算:
# inputs就是model预测的结果-----多个特征图查看原理部分损失计算
def criterion(inputs, target):
losses = [F.binary_cross_entropy_with_logits(inputs[i], target) for i in range(len(inputs))]
total_loss = sum(losses) # 求和
return total_loss
3. 验证
需要建立两个评价指标------- MeanAbsoluteError 和 F1Score
mae_metric, f1_metric = evaluate(model, val_data_loader, device=device)
...
...
def evaluate(,,,,):
model.eval()
mae_metric = utils.MeanAbsoluteError()
f1_metric = utils.F1Score()
...
with torch.no_grad():
for images, targets in data_loader:
images, targets = images.to(device), targets.to(device)
output = model(images)
mae_metric.update(output, targets)
f1_metric.update(output, targets)
return mae_metric, f1_metric
class MeanAbsoluteError(object):
def __init__(self):
self.mae_list = []
def update(self, pred: torch.Tensor, gt: torch.Tensor):
batch_size, c, h, w = gt.shape
assert batch_size == 1, f"validation mode batch_size must be 1, but got batch_size: {batch_size}."
resize_pred = F.interpolate(pred, (h, w), mode="bilinear", align_corners=False)
error_pixels = torch.sum(torch.abs(resize_pred - gt), dim=(1, 2, 3)) / (h * w)
self.mae_list.extend(error_pixels.tolist())
def compute(self):
mae = sum(self.mae_list) / len(self.mae_list)
return mae
def gather_from_all_processes(self):
if not torch.distributed.is_available():
return
if not torch.distributed.is_initialized():
return
torch.distributed.barrier()
gather_mae_list = []
for i in all_gather(self.mae_list):
gather_mae_list.extend(i)
self.mae_list = gather_mae_list
def __str__(self):
mae = self.compute()
return f'MAE: {mae:.3f}'
class F1Score(object):
"""
refer: https://github.com/xuebinqin/DIS/blob/main/IS-Net/basics.py
"""
def __init__(self, threshold: float = 0.5):
self.precision_cum = None
self.recall_cum = None
self.num_cum = None
self.threshold = threshold
def update(self, pred: torch.Tensor, gt: torch.Tensor):
batch_size, c, h, w = gt.shape
assert batch_size == 1, f"validation mode batch_size must be 1, but got batch_size: {batch_size}."
resize_pred = F.interpolate(pred, (h, w), mode="bilinear", align_corners=False)
gt_num = torch.sum(torch.gt(gt, self.threshold).float())
pp = resize_pred[torch.gt(gt, self.threshold)] # 对应预测map中GT为前景的区域
nn = resize_pred[torch.le(gt, self.threshold)] # 对应预测map中GT为背景的区域
pp_hist = torch.histc(pp, bins=255, min=0.0, max=1.0)
nn_hist = torch.histc(nn, bins=255, min=0.0, max=1.0)
# Sort according to the prediction probability from large to small
pp_hist_flip = torch.flipud(pp_hist)
nn_hist_flip = torch.flipud(nn_hist)
pp_hist_flip_cum = torch.cumsum(pp_hist_flip, dim=0)
nn_hist_flip_cum = torch.cumsum(nn_hist_flip, dim=0)
precision = pp_hist_flip_cum / (pp_hist_flip_cum + nn_hist_flip_cum + 1e-4)
recall = pp_hist_flip_cum / (gt_num + 1e-4)
if self.precision_cum is None:
self.precision_cum = torch.full_like(precision, fill_value=0.)
if self.recall_cum is None:
self.recall_cum = torch.full_like(recall, fill_value=0.)
if self.num_cum is None:
self.num_cum = torch.zeros([1], dtype=gt.dtype, device=gt.device)
self.precision_cum += precision
self.recall_cum += recall
self.num_cum += batch_size
def compute(self):
pre_mean = self.precision_cum / self.num_cum
rec_mean = self.recall_cum / self.num_cum
f1_mean = (1 + 0.3) * pre_mean * rec_mean / (0.3 * pre_mean + rec_mean + 1e-8)
max_f1 = torch.amax(f1_mean).item()
return max_f1
def reduce_from_all_processes(self):
if not torch.distributed.is_available():
return
if not torch.distributed.is_initialized():
return
torch.distributed.barrier()
torch.distributed.all_reduce(self.precision_cum)
torch.distributed.all_reduce(self.recall_cum)
torch.distributed.all_reduce(self.num_cum)
def __str__(self):
max_f1 = self.compute()
return f'maxF1: {max_f1:.3f}'
4. 预测
...
threshold = 0.5
origin_img = cv2.cvtColor(cv2.imread(img_path, flags=cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
img = data_transform(origin_img)
img = torch.unsqueeze(img, 0).to(device)
...
pred = model(img)
pred = torch.squeeze(pred).to("cpu").numpy() # [1, 1, H, W] -> [H, W]
# 生成mask结果
pred = cv2.resize(pred, dsize=(w, h), interpolation=cv2.INTER_LINEAR)
pred_mask = np.where(pred > threshold, 1, 0)
origin_img = np.array(origin_img, dtype=np.uint8)
seg_img = origin_img * pred_mask[..., None]
# 应该就是将原图中背景像素点变为0,前景像素点不变, 没去仔细查此行功能,待解决
# 保存
cv2.imwrite("pred_result.png", cv2.cvtColor(seg_img.astype(np.uint8), cv2.COLOR_RGB2BGR))
预测一张图片时,生成图如下:
训练以及验证时,数据集中的标签: