GitHub地址链接:https://github.com/NathanUA/U-2-Net
这个显著性检测很好用,强烈推荐,建议二分类的任务都来试试,尤其对边缘细节要求比较高的任务。
下面的效果要不是第一张图预测有瑕疵,我都以为预测代码是把标签复制了一下(+_+)
这里的精度我就不评价了,肉眼看就已经能说明问题了
原图
标签
预测结果:
1.数据准备
和传统的语义分割数据集存放目录是一致的,Image和Mask里面是名字一一对应(名字相同,后缀可以不同,代码里可以改后缀识别,jpg png这些都行)的图像—标签对
2.训练
data_loader.py,在222行左右,需要加copy()
# data loader
from __future__ import print_function, division
import glob
import torch
from skimage import io, transform, color
import numpy as np
import random
import math
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
#==========================dataset load==========================
class RescaleT(object):
def __init__(self,output_size):
assert isinstance(output_size,(int,tuple))
self.output_size = output_size
def __call__(self,sample):
imidx, image, label = sample['imidx'], sample['image'],sample['label']
h, w = image.shape[:2]
if isinstance(self.output_size,int):
if h > w:
new_h, new_w = self.output_size*h/w,self.output_size
else:
new_h, new_w = self.output_size,self.output_size*w/h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
# #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
# img = transform.resize(image,(new_h,new_w),mode='constant')
# lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
img = transform.resize(image,(self.output_size,self.output_size),mode='constant')
lbl = transform.resize(label,(self.output_size,self.output_size),mode='constant', order=0, preserve_range=True)
return {'imidx':imidx, 'image':img,'label':lbl}
class Rescale(object):
def __init__(self,output_size):
assert isinstance(output_size,(int,tuple))
self.output_size = output_size
def __call__(self,sample):
imidx, image, label = sample['imidx'], sample['image'],sample['label']
if random.random() >= 0.5:
image = image[::-1]
label = label[::-1]
h, w = image.shape[:2]
if isinstance(self.output_size,int):
if h > w:
new_h, new_w = self.output_size*h/w,self.output_size
else:
new_h, new_w = self.output_size,self.output_size*w/h
else:
new_h, new_w = self.output_size
new_h, new_w = int(new_h), int(new_w)
# #resize the image to new_h x new_w and convert image from range [0,255] to [0,1]
img = transform.resize(image,(new_h,new_w),mode='constant')
lbl = transform.resize(label,(new_h,new_w),mode='constant', order=0, preserve_range=True)
return {'imidx':imidx, 'image':img,'label':lbl}
class RandomCrop(object):
def __init__(self,output_size):
assert isinstance(output_size, (int, tuple))
if isinstance(output_size, int):
self.output_size = (output_size, output_size)
else:
assert len(output_size) == 2
self.output_size = output_size
def __call__(self,sample):
imidx, image, label = sample['imidx'], sample['image'], sample['label']
if random.random() >= 0.5:
image = image[::-1]
label = label[::-1]
h, w = image.shape[:2]
new_h, new_w = self.output_size
top = np.random.randint(0, h - new_h)
left = np.random.randint(0, w - new_w)
image = image[top: top + new_h, left: left + new_w]
label = label[top: top + new_h, left: left + new_w]
return {'imidx':imidx,'image':image, 'label':label}
class ToTensor(object):
"""Convert ndarrays in sample to Tensors."""
def __call__(self, sample):
imidx, image, label = sample['imidx'], sample['image'], sample['label']
tmpImg = np.zeros((image.shape[0],image.shape[1],3))
tmpLbl = np.zeros(label.shape)
image = image/np.max(image)
if(np.max(label)<1e-6):
label = label
else:
label = label/np.max(label)
if image.shape[2]==1:
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
else:
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
tmpLbl[:,:,0] = label[:,:,0]
# change the r,g,b to b,r,g from [0,255] to [0,1]
#transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
tmpImg = tmpImg.transpose((2, 0, 1))
tmpLbl = label.transpose((2, 0, 1))
return {'imidx':torch.from_numpy(imidx), 'image': torch.from_numpy(tmpImg), 'label': torch.from_numpy(tmpLbl)}
class ToTensorLab(object):
"""Convert ndarrays in sample to Tensors."""
def __init__(self,flag=0):
self.flag = flag
def __call__(self, sample):
imidx, image, label =sample['imidx'], sample['image'], sample['label']
tmpLbl = np.zeros(label.shape)
if(np.max(label)<1e-6):
label = label
else:
label = label/np.max(label)
# change the color space
if self.flag == 2: # with rgb and Lab colors
tmpImg = np.zeros((image.shape[0],image.shape[1],6))
tmpImgt = np.zeros((image.shape[0],image.shape[1],3))
if image.shape[2]==1:
tmpImgt[:,:,0] = image[:,:,0]
tmpImgt[:,:,1] = image[:,:,0]
tmpImgt[:,:,2] = image[:,:,0]
else:
tmpImgt = image
tmpImgtl = color.rgb2lab(tmpImgt)
# nomalize image to range [0,1]
tmpImg[:,:,0] = (tmpImgt[:,:,0]-np.min(tmpImgt[:,:,0]))/(np.max(tmpImgt[:,:,0])-np.min(tmpImgt[:,:,0]))
tmpImg[:,:,1] = (tmpImgt[:,:,1]-np.min(tmpImgt[:,:,1]))/(np.max(tmpImgt[:,:,1])-np.min(tmpImgt[:,:,1]))
tmpImg[:,:,2] = (tmpImgt[:,:,2]-np.min(tmpImgt[:,:,2]))/(np.max(tmpImgt[:,:,2])-np.min(tmpImgt[:,:,2]))
tmpImg[:,:,3] = (tmpImgtl[:,:,0]-np.min(tmpImgtl[:,:,0]))/(np.max(tmpImgtl[:,:,0])-np.min(tmpImgtl[:,:,0]))
tmpImg[:,:,4] = (tmpImgtl[:,:,1]-np.min(tmpImgtl[:,:,1]))/(np.max(tmpImgtl[:,:,1])-np.min(tmpImgtl[:,:,1]))
tmpImg[:,:,5] = (tmpImgtl[:,:,2]-np.min(tmpImgtl[:,:,2]))/(np.max(tmpImgtl[:,:,2])-np.min(tmpImgtl[:,:,2]))
# tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
tmpImg[:,:,3] = (tmpImg[:,:,3]-np.mean(tmpImg[:,:,3]))/np.std(tmpImg[:,:,3])
tmpImg[:,:,4] = (tmpImg[:,:,4]-np.mean(tmpImg[:,:,4]))/np.std(tmpImg[:,:,4])
tmpImg[:,:,5] = (tmpImg[:,:,5]-np.mean(tmpImg[:,:,5]))/np.std(tmpImg[:,:,5])
elif self.flag == 1: #with Lab color
tmpImg = np.zeros((image.shape[0],image.shape[1],3))
if image.shape[2]==1:
tmpImg[:,:,0] = image[:,:,0]
tmpImg[:,:,1] = image[:,:,0]
tmpImg[:,:,2] = image[:,:,0]
else:
tmpImg = image
tmpImg = color.rgb2lab(tmpImg)
# tmpImg = tmpImg/(np.max(tmpImg)-np.min(tmpImg))
tmpImg[:,:,0] = (tmpImg[:,:,0]-np.min(tmpImg[:,:,0]))/(np.max(tmpImg[:,:,0])-np.min(tmpImg[:,:,0]))
tmpImg[:,:,1] = (tmpImg[:,:,1]-np.min(tmpImg[:,:,1]))/(np.max(tmpImg[:,:,1])-np.min(tmpImg[:,:,1]))
tmpImg[:,:,2] = (tmpImg[:,:,2]-np.min(tmpImg[:,:,2]))/(np.max(tmpImg[:,:,2])-np.min(tmpImg[:,:,2]))
tmpImg[:,:,0] = (tmpImg[:,:,0]-np.mean(tmpImg[:,:,0]))/np.std(tmpImg[:,:,0])
tmpImg[:,:,1] = (tmpImg[:,:,1]-np.mean(tmpImg[:,:,1]))/np.std(tmpImg[:,:,1])
tmpImg[:,:,2] = (tmpImg[:,:,2]-np.mean(tmpImg[:,:,2]))/np.std(tmpImg[:,:,2])
else: # with rgb color
tmpImg = np.zeros((image.shape[0],image.shape[1],3))
image = image/np.max(image)
if image.shape[2]==1:
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,2] = (image[:,:,0]-0.485)/0.229
else:
tmpImg[:,:,0] = (image[:,:,0]-0.485)/0.229
tmpImg[:,:,1] = (image[:,:,1]-0.456)/0.224
tmpImg[:,:,2] = (image[:,:,2]-0.406)/0.225
tmpLbl[:,:,0] = label[:,:,0]
# change the r,g,b to b,r,g from [0,255] to [0,1]
#transforms.Normalize(mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225))
tmpImg = tmpImg.transpose((2, 0, 1))
tmpLbl = label.transpose((2, 0, 1))
return {'imidx':torch.from_numpy(imidx.copy()), 'image': torch.from_numpy(tmpImg.copy()), 'label': torch.from_numpy(tmpLbl.copy())} #需要加.copy()
class SalObjDataset(Dataset):
def __init__(self,img_name_list,lbl_name_list,transform=None):
# self.root_dir = root_dir
# self.image_name_list = glob.glob(image_dir+'*.png')
# self.label_name_list = glob.glob(label_dir+'*.png')
self.image_name_list = img_name_list
self.label_name_list = lbl_name_list
self.transform = transform
def __len__(self):
return len(self.image_name_list)
def __getitem__(self,idx):
# image = Image.open(self.image_name_list[idx])#io.imread(self.image_name_list[idx])
# label = Image.open(self.label_name_list[idx])#io.imread(self.label_name_list[idx])
image = io.imread(self.image_name_list[idx])
imname = self.image_name_list[idx]
imidx = np.array([idx])
if(0==len(self.label_name_list)):
label_3 = np.zeros(image.shape)
else:
label_3 = io.imread(self.label_name_list[idx])
label = np.zeros(label_3.shape[0:2])
if(3==len(label_3.shape)):
label = label_3[:,:,0]
elif(2==len(label_3.shape)):
label = label_3
if(3==len(image.shape) and 2==len(label.shape)):
label = label[:,:,np.newaxis]
elif(2==len(image.shape) and 2==len(label.shape)):
image = image[:,:,np.newaxis]
label = label[:,:,np.newaxis]
sample = {'imidx':imidx, 'image':image, 'label':label}
if self.transform:
sample = self.transform(sample)
return sample
训练代码基本不用改,主要是路径拼接能找到图像路径就好了
u2net_train.py
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms
import os
import numpy as np
import glob
from data_loader import Rescale
from data_loader import RescaleT
from data_loader import RandomCrop
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET
from model import U2NETP
# ------- 1. define loss function --------
bce_loss = nn.BCELoss(size_average=True)
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
loss0 = bce_loss(d0,labels_v)
loss1 = bce_loss(d1,labels_v)
loss2 = bce_loss(d2,labels_v)
loss3 = bce_loss(d3,labels_v)
loss4 = bce_loss(d4,labels_v)
loss5 = bce_loss(d5,labels_v)
loss6 = bce_loss(d6,labels_v)
loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n"%(loss0.item(),loss1.item(),loss2.item(),loss3.item(),loss4.item(),loss5.item(),loss6.item())) #源码里这里有可能报错,如果报了改成和我一样应该就好了
return loss0, loss
# ------- 2. set the directory of training dataset --------
model_name = 'u2net' #'u2netp or u2net' #有两个模型可以选,选择的标记在下面训练开始前
data_dir = os.path.join(os.getcwd(), 'data' + os.sep)
# tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'im_aug' + os.sep)
# tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR', 'gt_aug' + os.sep)
tra_image_dir = os.path.join('RIVER', 'Train', 'Image/') #训练图片的路径
tra_label_dir = os.path.join('RIVER', 'Train', 'Mask/')
image_ext = '.png' #图像后缀,改成自己的图片格式
label_ext = '.png'
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep) #模型存储位置
#训练相关参数
epoch_num = 2000
batch_size_train = 6
batch_size_val = 2
train_num = 0
val_num = 0
tra_img_name_list = glob.glob(data_dir + tra_image_dir + '*' + image_ext)
tra_lbl_name_list = []
for img_path in tra_img_name_list:
img_name = img_path.split(os.sep)[-1]
aaa = img_name.split(".")
bbb = aaa[0:-1]
imidx = bbb[0]
for i in range(1,len(bbb)):
imidx = imidx + "." + bbb[i]
tra_lbl_name_list.append(data_dir + tra_label_dir + imidx + label_ext)
print("---")
print("train images: ", len(tra_img_name_list))
print("train labels: ", len(tra_lbl_name_list))
print("---")
train_num = len(tra_img_name_list)
salobj_dataset = SalObjDataset(
img_name_list=tra_img_name_list,
lbl_name_list=tra_lbl_name_list,
transform=transforms.Compose([
RescaleT(320),
RandomCrop(288),
ToTensorLab(flag=0)]))
salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train, shuffle=True, num_workers=0)
# ------- 3. define model --------
# define the net
#选择模型
if(model_name=='u2net'):
net = U2NET(3, 1)
elif(model_name=='u2netp'):
net = U2NETP(3,1)
if torch.cuda.is_available():
net.cuda()
# ------- 4. define optimizer --------
print("---define optimizer...")
optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
# ------- 5. training process --------
print("---start training...")
ite_num = 0
running_loss = 0.0
running_tar_loss = 0.0
ite_num4val = 0
save_frq = 2000 # save the model every 2000 iterations 每迭代两千次存一次模型,这个可以在下面改成每个epoch存,很好改,这里我就不改了
for epoch in range(0, epoch_num):
net.train()
for i, data in enumerate(salobj_dataloader):
ite_num = ite_num + 1
ite_num4val = ite_num4val + 1
inputs, labels = data['image'], data['label']
inputs = inputs.type(torch.FloatTensor)
labels = labels.type(torch.FloatTensor)
# wrap them in Variable
if torch.cuda.is_available():
inputs_v, labels_v = Variable(inputs.cuda(), requires_grad=False), Variable(labels.cuda(),
requires_grad=False)
else:
inputs_v, labels_v = Variable(inputs, requires_grad=False), Variable(labels, requires_grad=False)
# y zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
loss.backward()
optimizer.step()
# # print statistics
running_loss += loss.item()
running_tar_loss += loss2.item()
# del temporary outputs and loss
del d0, d1, d2, d3, d4, d5, d6, loss2, loss
print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
if ite_num % save_frq == 0:
torch.save(net.state_dict(), model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
running_loss = 0.0
running_tar_loss = 0.0
net.train() # resume train
ite_num4val = 0
3.预测
u2net_test.py
import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim
import numpy as np
from PIL import Image
import glob
from data_loader import RescaleT
from data_loader import ToTensor
from data_loader import ToTensorLab
from data_loader import SalObjDataset
from model import U2NET # full size version 173.6 MB
from model import U2NETP # small version u2net 4.7 MB
# normalize the predicted SOD probability map
def normPRED(d):
ma = torch.max(d)
mi = torch.min(d)
dn = (d-mi)/(ma-mi)
return dn
def save_output(image_name,pred,d_dir):
predict = pred
predict = predict.squeeze()
predict_np = predict.cpu().data.numpy()
im = Image.fromarray(predict_np*255).convert('RGB')
img_name = image_name.split(os.sep)[-1]
image = io.imread(image_name)
imo = im.resize((image.shape[1],image.shape[0]),resample=Image.BILINEAR)
pb_np = np.array(imo)
aaa = img_name.split(".")
bbb = aaa[0:-1]
imidx = bbb[0]
for i in range(1,len(bbb)):
imidx = imidx + "." + bbb[i]
imo.save(d_dir+imidx+'.png')
def main():
# --------- 1. get image path and name ---------
model_name='u2net'#u2netp #模型名字和训练一致
# image_dir = os.path.join(os.getcwd(), 'test_data', 'test_images')
image_dir = 'D:/wcs/U-2-Net/data/RIVER/Test/Image/' #测试图像路径
# prediction_dir = os.path.join(os.getcwd(), 'test_data', model_name + '_results' + os.sep)
prediction_dir = 'D:/wcs/U-2-Net/data/RIVER/Test/pre/' #保存结果路径
model_dir = os.path.join(os.getcwd(), 'saved_models', model_name, model_name + '.pth')
img_name_list = glob.glob(image_dir + os.sep + '*')
# print(img_name_list)
# --------- 2. dataloader ---------
#1. dataloader
test_salobj_dataset = SalObjDataset(img_name_list = img_name_list,
lbl_name_list = [],
transform=transforms.Compose([RescaleT(320),
ToTensorLab(flag=0)])
)
test_salobj_dataloader = DataLoader(test_salobj_dataset,
batch_size=1,
shuffle=False,
num_workers=1)
# --------- 3. model define ---------
if(model_name=='u2net'):
print("...load U2NET---173.6 MB")
net = U2NET(3,1)
elif(model_name=='u2netp'):
print("...load U2NEP---4.7 MB")
net = U2NETP(3,1)
# net.load_state_dict(torch.load(model_dir))
net.load_state_dict(torch.load('./saved_models/u2net/u2net_bce_itr_36000_train_0.091362_tar_0.003286.pth')) #加载自己的模型
if torch.cuda.is_available():
net.cuda()
net.eval()
# --------- 4. inference for each image ---------
for i_test, data_test in enumerate(test_salobj_dataloader):
print("inferencing:",img_name_list[i_test].split(os.sep)[-1])
inputs_test = data_test['image']
inputs_test = inputs_test.type(torch.FloatTensor)
if torch.cuda.is_available():
inputs_test = Variable(inputs_test.cuda())
else:
inputs_test = Variable(inputs_test)
d1,d2,d3,d4,d5,d6,d7= net(inputs_test)
# normalization
# print(d7.shape)
pred = d1[:,0,:,:] #注意这里,这个d1是融合了d2,d3,d4,d5,d6,d7的,如果想了解具体就翻到网络模型去自习看看
pred = normPRED(pred)
# save results to test_results folder
if not os.path.exists(prediction_dir):
os.makedirs(prediction_dir, exist_ok=True)
save_output(img_name_list[i_test],pred,prediction_dir)
del d1,d2,d3,d4,d5,d6,d7
if __name__ == "__main__":
main()
另外,这个项目在GitHub上展示的人像轮廓提取效果非常好,说明模型关注细节的能力很强,建议需要提取线的相关任务也做做尝试