# DeepLearning—CV系列（十九）——图像分割之U^2-Net（效果极好）的Pytorch实现

### 文章目录

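（1）普通卷积之后特征图尺寸的计算公式（n 为输入尺寸，k 为卷积核大小，p 为填充，s 为步长）：
$w = h = \lfloor (n + 2p - k) / s \rfloor + 1$
（2）进行空洞卷积（膨胀率为 d）后的计算公式：
$w = h = \lfloor (n + 2p - d(k - 1) - 1) / s \rfloor + 1$
（3）进行池化后的特征图计算公式（下文代码中的 MaxPool2d 使用 ceil_mode=True，此时改为向上取整）：
$w = h = \lfloor (n - k) / s \rfloor + 1$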

# 一、u2net.py

EN_1：2次卷积、5次下采样、1个空洞卷积、5层上采样（每次插值上采样都需要做一次卷积）。

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F

class REBNCONV(nn.Module):
    """Convolution block: 3x3 (optionally dilated) conv -> batch norm -> ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        dirate: Dilation rate of the 3x3 convolution. The padding is set to
            the same value so the spatial size of the input is preserved.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1):
        super(REBNCONV, self).__init__()

        # forward() calls self.conv_s1; without this layer the block raises
        # AttributeError on first use.
        self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> batch norm -> ReLU to ``x`` and return the result."""
        return self.relu_s1(self.bn_s1(self.conv_s1(x)))

## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):

src = F.upsample(src,size=tar.shape[2:],mode='bilinear')

return src

### RSU-7 ###
class RSU7(nn.Module):  # UNet07DRES(nn.Module):
    """Residual U-block with six encoder convs, a dilated bridge and six decoder convs.

    Args:
        in_ch: Channels of the block input.
        mid_ch: Channels used inside the inner encoder/decoder.
        out_ch: Channels of the block output (and of the residual branch).
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; its output is added back to the result in forward().
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: each conv but the last is followed by a 2x max-pool.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bridge: dilated conv at the coarsest resolution.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: every conv consumes [decoder feature, skip feature].
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the inner U-structure on ``x`` and add the input projection."""
        residual = self.rebnconvin(x)

        # Encoder path; e1..e6 are kept as skip connections.
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))
        e5 = self.rebnconv5(self.pool4(e4))
        e6 = self.rebnconv6(self.pool5(e5))

        bridge = self.rebnconv7(e6)

        # Decoder path: the first stage needs no upsampling (same resolution
        # as e6); afterwards upsample to the next skip, concatenate, convolve.
        dec = self.rebnconv6d(torch.cat((bridge, e6), 1))
        stages = (
            (self.rebnconv5d, e5),
            (self.rebnconv4d, e4),
            (self.rebnconv3d, e3),
            (self.rebnconv2d, e2),
            (self.rebnconv1d, e1),
        )
        for conv, skip in stages:
            dec = conv(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual

### RSU-6 ###
class RSU6(nn.Module):  # UNet06DRES(nn.Module):
    """Residual U-block with five encoder convs, a dilated bridge and five decoder convs.

    Args:
        in_ch: Channels of the block input.
        mid_ch: Channels used inside the inner encoder/decoder.
        out_ch: Channels of the block output (and of the residual branch).
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; its output is added back to the result in forward().
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: each conv but the last is followed by a 2x max-pool.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bridge: dilated conv at the coarsest resolution.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: every conv consumes [decoder feature, skip feature].
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the inner U-structure on ``x`` and add the input projection."""
        residual = self.rebnconvin(x)

        # Encoder path; e1..e5 are kept as skip connections.
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))
        e5 = self.rebnconv5(self.pool4(e4))

        bridge = self.rebnconv6(e5)

        # Decoder path: upsample to the next skip, concatenate, convolve.
        dec = self.rebnconv5d(torch.cat((bridge, e5), 1))
        stages = (
            (self.rebnconv4d, e4),
            (self.rebnconv3d, e3),
            (self.rebnconv2d, e2),
            (self.rebnconv1d, e1),
        )
        for conv, skip in stages:
            dec = conv(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual

### RSU-5 ###
class RSU5(nn.Module):  # UNet05DRES(nn.Module):
    """Residual U-block with four encoder convs, a dilated bridge and four decoder convs.

    Args:
        in_ch: Channels of the block input.
        mid_ch: Channels used inside the inner encoder/decoder.
        out_ch: Channels of the block output (and of the residual branch).
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; its output is added back to the result in forward().
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: each conv but the last is followed by a 2x max-pool.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bridge: dilated conv at the coarsest resolution.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: every conv consumes [decoder feature, skip feature].
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the inner U-structure on ``x`` and add the input projection."""
        residual = self.rebnconvin(x)

        # Encoder path; e1..e4 are kept as skip connections.
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))
        e4 = self.rebnconv4(self.pool3(e3))

        bridge = self.rebnconv5(e4)

        # Decoder path: upsample to the next skip, concatenate, convolve.
        dec = self.rebnconv4d(torch.cat((bridge, e4), 1))
        stages = (
            (self.rebnconv3d, e3),
            (self.rebnconv2d, e2),
            (self.rebnconv1d, e1),
        )
        for conv, skip in stages:
            dec = conv(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual

### RSU-4 ###
class RSU4(nn.Module):  # UNet04DRES(nn.Module):
    """Residual U-block with three encoder convs, a dilated bridge and three decoder convs.

    Args:
        in_ch: Channels of the block input.
        mid_ch: Channels used inside the inner encoder/decoder.
        out_ch: Channels of the block output (and of the residual branch).
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; its output is added back to the result in forward().
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: each conv but the last is followed by a 2x max-pool.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Bridge: dilated conv at the coarsest resolution.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: every conv consumes [decoder feature, skip feature].
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the inner U-structure on ``x`` and add the input projection."""
        residual = self.rebnconvin(x)

        # Encoder path; e1..e3 are kept as skip connections.
        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(self.pool1(e1))
        e3 = self.rebnconv3(self.pool2(e2))

        bridge = self.rebnconv4(e3)

        # Decoder path: upsample to the next skip, concatenate, convolve.
        dec = self.rebnconv3d(torch.cat((bridge, e3), 1))
        for conv, skip in ((self.rebnconv2d, e2), (self.rebnconv1d, e1)):
            dec = conv(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual

### RSU-4F ###
class RSU4F(nn.Module):  # UNet04FRES(nn.Module):
    """Residual U-block without pooling: dilation rates 1, 2, 4, 8 replace down/upsampling.

    Args:
        in_ch: Channels of the block input.
        mid_ch: Channels used inside the inner encoder/decoder.
        out_ch: Channels of the block output (and of the residual branch).
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; its output is added back to the result in forward().
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder with growing dilation instead of pooling.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        # Bridge with the largest dilation.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder with shrinking dilation; inputs are [decoder, skip] pairs.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Run the dilated encoder/decoder on ``x`` and add the input projection."""
        residual = self.rebnconvin(x)

        e1 = self.rebnconv1(residual)
        e2 = self.rebnconv2(e1)
        e3 = self.rebnconv3(e2)

        dec = self.rebnconv4(e3)

        # All features share one resolution, so no upsampling is performed.
        for conv, skip in ((self.rebnconv3d, e3), (self.rebnconv2d, e2), (self.rebnconv1d, e1)):
            dec = conv(torch.cat((dec, skip), 1))

        return dec + residual

##### U^2-Net ####
class U2NET(nn.Module):
    """Full-size U^2-Net: a six-stage encoder/decoder built from RSU blocks.

    Args:
        in_ch: Channels of the input image.
        out_ch: Channels of each predicted map.

    ``forward`` returns seven sigmoid-activated maps ``(d0, d1, ..., d6)``:
    ``d0`` is the fusion of the six side outputs, ``d1``..``d6`` are the side
    outputs themselves, all resized to the resolution of ``d1``.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NET, self).__init__()

        # encoder
        self.stage1 = RSU7(in_ch, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side-output heads used by forward(); input channels match the
        # out_ch of the stage each head reads from.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # Fuses the six concatenated side outputs (6 * out_ch channels;
        # identical to the previous hard-coded 6 when out_ch == 1).
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        """Return ``(d0, d1, d2, d3, d4, d5, d6)`` probability maps for ``x``."""
        # -------------------- encoder --------------------
        hx1 = self.stage1(x)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # ------------------ side outputs ------------------
        d1 = self.side1(hx1d)
        d2 = _upsample_like(self.side2(hx2d), d1)
        d3 = _upsample_like(self.side3(hx3d), d1)
        d4 = _upsample_like(self.side4(hx4d), d1)
        d5 = _upsample_like(self.side5(hx5d), d1)
        d6 = _upsample_like(self.side6(hx6), d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # torch.sigmoid replaces the deprecated F.sigmoid.
        return tuple(torch.sigmoid(d) for d in (d0, d1, d2, d3, d4, d5, d6))

### U^2-Net small ###
class U2NETP(nn.Module):
    """Small U^2-Net: same topology as U2NET with 64-channel stages and 16 inner channels.

    Args:
        in_ch: Channels of the input image.
        out_ch: Channels of each predicted map.

    ``forward`` returns seven sigmoid-activated maps ``(d0, d1, ..., d6)``:
    ``d0`` is the fusion of the six side outputs, ``d1``..``d6`` are the side
    outputs themselves, all resized to the resolution of ``d1``.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super(U2NETP, self).__init__()

        # encoder
        self.stage1 = RSU7(in_ch, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 16, 64)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(64, 16, 64)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(64, 16, 64)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(64, 16, 64)

        # decoder
        self.stage5d = RSU4F(128, 16, 64)
        self.stage4d = RSU4(128, 16, 64)
        self.stage3d = RSU5(128, 16, 64)
        self.stage2d = RSU6(128, 16, 64)
        self.stage1d = RSU7(128, 16, 64)

        # Side-output heads used by forward(); every stage emits 64 channels.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(64, out_ch, 3, padding=1)

        # Fuses the six concatenated side outputs (6 * out_ch channels;
        # identical to the previous hard-coded 6 when out_ch == 1).
        self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)

    def forward(self, x):
        """Return ``(d0, d1, d2, d3, d4, d5, d6)`` probability maps for ``x``."""
        # -------------------- encoder --------------------
        hx1 = self.stage1(x)
        hx2 = self.stage2(self.pool12(hx1))
        hx3 = self.stage3(self.pool23(hx2))
        hx4 = self.stage4(self.pool34(hx3))
        hx5 = self.stage5(self.pool45(hx4))
        hx6 = self.stage6(self.pool56(hx5))
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # ------------------ side outputs ------------------
        d1 = self.side1(hx1d)
        d2 = _upsample_like(self.side2(hx2d), d1)
        d3 = _upsample_like(self.side3(hx3d), d1)
        d4 = _upsample_like(self.side4(hx4d), d1)
        d5 = _upsample_like(self.side5(hx5d), d1)
        d6 = _upsample_like(self.side6(hx6), d1)

        d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))

        # torch.sigmoid replaces the deprecated F.sigmoid.
        return tuple(torch.sigmoid(d) for d in (d0, d1, d2, d3, d4, d5, d6))



# 二、data_loader.py

# data loader
from __future__ import print_function, division
import glob
import math
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from skimage import io, transform, color
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

class RescaleT(object):
    """Resize the image and label of a sample to a fixed size (aspect ratio is not kept).

    Args:
        output_size: ``int`` for a square ``output_size x output_size`` result,
            or a ``(height, width)`` tuple.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with resized ``image`` and ``label``."""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # An int means a square target. A tuple is used as (height, width);
        # previously it was nested into ((h, w), (h, w)), which is not a
        # valid output shape for skimage's resize.
        if isinstance(self.output_size, int):
            target = (self.output_size, self.output_size)
        else:
            target = tuple(self.output_size)

        img = transform.resize(image, target, mode='constant')
        # order=0 (nearest neighbour) and preserve_range keep the label values intact.
        lbl = transform.resize(label, target, mode='constant', order=0,
                               preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}

class Rescale(object):
    """Randomly flip a sample upside down, then resize it.

    Args:
        output_size: ``int`` to make the shorter side that long while keeping
            the aspect ratio, or a ``(height, width)`` tuple used as is.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict with the (maybe flipped) resized image and label."""
        imidx = sample['imidx']
        image = sample['image']
        label = sample['label']

        # Reverse the first axis of both arrays with probability 0.5.
        if random.random() >= 0.5:
            image, label = image[::-1], label[::-1]

        height, width = image.shape[:2]

        if not isinstance(self.output_size, int):
            new_h, new_w = self.output_size
        elif height > width:
            new_h, new_w = self.output_size * height / width, self.output_size
        else:
            new_h, new_w = self.output_size, self.output_size * width / height

        target = (int(new_h), int(new_w))

        # order=0 and preserve_range keep the label values intact.
        img = transform.resize(image, target, mode='constant')
        lbl = transform.resize(label, target, mode='constant', order=0, preserve_range=True)

        return {'imidx': imidx, 'image': img, 'label': lbl}

class RandomCrop(object):
    """Randomly flip a sample upside down, then cut a random window out of it.

    Args:
        output_size: ``int`` for a square crop or a ``(height, width)`` tuple.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size

    def __call__(self, sample):
        """Return a new sample dict whose image and label share the same crop window."""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # Reverse the first axis of both arrays with probability 0.5.
        if random.random() >= 0.5:
            image = image[::-1]
            label = label[::-1]

        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # randint's upper bound is exclusive: "+ 1" makes the last valid
        # offset reachable and lets an image that already has the crop size
        # pass through instead of raising ValueError (low >= high).
        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)

        image = image[top: top + new_h, left: left + new_w]
        label = label[top: top + new_h, left: left + new_w]

        return {'imidx': imidx, 'image': image, 'label': label}

class ToTensor(object):
    """Convert ndarrays in sample to Tensors.

    The image is scaled by its maximum, normalised per channel with the
    mean/std constants below and moved to channel-first layout; the label is
    scaled by its maximum (when non-zero) and moved to channel-first layout.
    """

    # Per-channel normalisation constants (R, G, B).
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __call__(self, sample):
        """Return ``{'imidx', 'image', 'label'}`` as torch tensors."""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        # Scale to [0, 1]; skip the division for an all-zero image, which
        # would otherwise produce NaNs (0 / 0).
        peak = np.max(image)
        if peak != 0:
            image = image / peak

        if np.max(label) >= 1e-6:
            label = label / np.max(label)

        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
        gray = image.shape[2] == 1
        for c in range(3):
            # A single-channel image is replicated using the first channel's
            # statistics for all three output channels.
            src = 0 if gray else c
            tmpImg[:, :, c] = (image[:, :, src] - self._MEAN[src]) / self._STD[src]

        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        # torch.from_numpy rejects arrays with negative strides, which the
        # [::-1] flips of the augmentation transforms produce; make
        # contiguous copies first.
        return {'imidx': torch.from_numpy(np.ascontiguousarray(imidx)),
                'image': torch.from_numpy(np.ascontiguousarray(tmpImg)),
                'label': torch.from_numpy(np.ascontiguousarray(tmpLbl))}

class ToTensorLab(object):
    """Convert ndarrays in sample to Tensors.

    Args:
        flag: Colour representation of the output image:
            ``2`` - six channels, RGB followed by Lab, each min-max scaled and
            then standardised per channel;
            ``1`` - three Lab channels, min-max scaled and standardised;
            anything else - three RGB channels scaled by the image maximum and
            normalised with fixed mean/std constants.
    """

    # Per-channel normalisation constants (R, G, B) for the RGB mode.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, flag=0):
        self.flag = flag

    @staticmethod
    def _min_max(channel):
        """Scale ``channel`` to [0, 1]; a constant channel maps to zeros instead of NaN."""
        low = np.min(channel)
        span = np.max(channel) - low
        if span == 0:
            return np.zeros(channel.shape)
        return (channel - low) / span

    @staticmethod
    def _standardize(channel):
        """Subtract the mean and divide by the std; a zero std skips the division."""
        centered = channel - np.mean(channel)
        std = np.std(channel)
        if std == 0:
            return centered
        return centered / std

    @staticmethod
    def _as_three_channels(image):
        """Replicate a single-channel image to three float channels; pass others through."""
        if image.shape[2] == 1:
            return np.repeat(image, 3, axis=2).astype(np.float64)
        return image

    def __call__(self, sample):
        """Return ``{'imidx', 'image', 'label'}`` as torch tensors (channel-first)."""
        imidx, image, label = sample['imidx'], sample['image'], sample['label']

        if np.max(label) >= 1e-6:
            label = label / np.max(label)

        # change the color space
        if self.flag == 2:  # with rgb and Lab colors
            rgb = self._as_three_channels(image)
            lab = color.rgb2lab(rgb)

            tmpImg = np.zeros((image.shape[0], image.shape[1], 6))
            for c in range(3):
                tmpImg[:, :, c] = self._min_max(rgb[:, :, c])
                tmpImg[:, :, c + 3] = self._min_max(lab[:, :, c])
            for c in range(6):
                tmpImg[:, :, c] = self._standardize(tmpImg[:, :, c])

        elif self.flag == 1:  # with Lab color
            lab = color.rgb2lab(self._as_three_channels(image))

            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            for c in range(3):
                tmpImg[:, :, c] = self._standardize(self._min_max(lab[:, :, c]))

        else:  # with rgb color
            # Skip the division for an all-zero image (would give 0 / 0 = NaN).
            peak = np.max(image)
            if peak != 0:
                image = image / peak

            tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
            gray = image.shape[2] == 1
            for c in range(3):
                # A single-channel image reuses the first channel's constants.
                src = 0 if gray else c
                tmpImg[:, :, c] = (image[:, :, src] - self._MEAN[src]) / self._STD[src]

        tmpImg = tmpImg.transpose((2, 0, 1))
        tmpLbl = label.transpose((2, 0, 1))

        # torch.from_numpy rejects arrays with negative strides, which the
        # [::-1] flips of the augmentation transforms produce; make
        # contiguous copies first.
        return {'imidx': torch.from_numpy(np.ascontiguousarray(imidx)),
                'image': torch.from_numpy(np.ascontiguousarray(tmpImg)),
                'label': torch.from_numpy(np.ascontiguousarray(tmpLbl))}

class SalObjDataset(Dataset):
    """Dataset of image / label file pairs read from disk on access.

    Args:
        img_name_list: Paths of the input images.
        lbl_name_list: Paths of the label masks, index-aligned with
            ``img_name_list``. An empty list means "no labels" (inference):
            an all-zero label is generated instead.
        transform: Optional callable applied to each
            ``{'imidx', 'image', 'label'}`` sample dict.
    """

    def __init__(self, img_name_list, lbl_name_list, transform=None):
        self.image_name_list = img_name_list
        self.label_name_list = lbl_name_list
        self.transform = transform

    def __len__(self):
        return len(self.image_name_list)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return it as a dict (after ``transform``, if any)."""
        # The image must be read first: its shape sizes the dummy label below.
        image = io.imread(self.image_name_list[idx])
        imidx = np.array([idx])

        if len(self.label_name_list) == 0:
            label_3 = np.zeros(image.shape)
        else:
            label_3 = io.imread(self.label_name_list[idx])

        # Reduce the label to a single 2-D plane (first channel of a
        # multi-channel mask); any other rank falls back to zeros.
        label = np.zeros(label_3.shape[0:2])
        if label_3.ndim == 3:
            label = label_3[:, :, 0]
        elif label_3.ndim == 2:
            label = label_3

        # Give both arrays an explicit trailing channel axis.
        if image.ndim == 3 and label.ndim == 2:
            label = label[:, :, np.newaxis]
        elif image.ndim == 2 and label.ndim == 2:
            image = image[:, :, np.newaxis]
            label = label[:, :, np.newaxis]

        sample = {'imidx': imidx, 'image': image, 'label': label}

        if self.transform:
            sample = self.transform(sample)

        return sample



# 三、train.py

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from torchvision import transforms, utils
import torch.optim as optim
import torchvision.transforms as standard_transforms

import numpy as np
import glob

from u2net import U2NET
from u2net import U2NETP

# ------- 1. define loss function --------

# reduction='mean' is the non-deprecated spelling of size_average=True.
bce_loss = nn.BCELoss(reduction='mean')


def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    """Sum the BCE losses of the fused map and the six side outputs.

    Args:
        d0: Fused probability map.
        d1, d2, d3, d4, d5, d6: Side-output probability maps.
        labels_v: Ground-truth map compared against every prediction.

    Returns:
        ``(loss0, loss)``: the BCE loss of ``d0`` alone and the sum of all
        seven losses. The individual losses are printed as a side effect.
    """
    losses = [bce_loss(d, labels_v) for d in (d0, d1, d2, d3, d4, d5, d6)]

    # Accumulate left to right starting from the fused-map loss.
    loss = losses[0]
    for side_loss in losses[1:]:
        loss = loss + side_loss

    print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % tuple(
        item.item() for item in losses))

    return losses[0], loss

def main():
    """Train U^2-Net (or U2NETP) on the DUTS-TR images with the fused BCE loss.

    A checkpoint is written every ``save_frq`` iterations to
    ``./saved_models/<model_name>/<model_name>.pth``.

    NOTE(review): RescaleT, RandomCrop, ToTensorLab and SalObjDataset are the
    data-loader classes; import them here if they live in a separate module.
    """
    import os
    from torch.utils.data import DataLoader

    # ------- 2. set the directory of training dataset --------
    model_name = 'u2net'  # 'u2netp'

    data_dir = r'C:\datasets'
    tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR-Image')
    # NOTE(review): the label directory was not defined in the original
    # script; 'DUTS-TR-Mask' is assumed - confirm against the dataset layout.
    tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'DUTS-TR-Mask')

    label_ext = '.png'

    model_dir = './saved_models/' + model_name + '/'

    epoch_num = 100000
    batch_size_train = 4

    tra_img_name_list = glob.glob(os.path.join(data_dir, tra_image_dir, '*'))
    if not tra_img_name_list:
        raise FileNotFoundError(
            "no training images found in " + os.path.join(data_dir, tra_image_dir))

    # A label shares its image's base name (everything before the last dot).
    tra_lbl_name_list = []
    for img_path in tra_img_name_list:
        imidx = os.path.splitext(os.path.basename(img_path))[0]
        tra_lbl_name_list.append(os.path.join(data_dir, tra_label_dir, imidx + label_ext))

    print("---")
    print("train images: ", len(tra_img_name_list))
    print("train labels: ", len(tra_lbl_name_list))
    print("---")

    train_num = len(tra_img_name_list)

    salobj_dataset = SalObjDataset(
        img_name_list=tra_img_name_list,
        lbl_name_list=tra_lbl_name_list,
        transform=transforms.Compose([
            RescaleT(320),
            RandomCrop(288),
            ToTensorLab(flag=0)]))
    salobj_dataloader = DataLoader(salobj_dataset, batch_size=batch_size_train,
                                   shuffle=True, num_workers=1)

    # ------- 3. define model --------
    if model_name == 'u2net':
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        net = U2NETP(3, 1)
    else:
        raise ValueError("unknown model_name: " + model_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net.to(device)

    # ------- 4. define optimizer --------
    print("---define optimizer...")
    optimizer = optim.Adam(net.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

    # ------- 5. training process --------
    print("---start training...")
    ite_num = 0
    running_loss = 0.0
    running_tar_loss = 0.0
    ite_num4val = 0
    save_frq = 2000  # save the model every 2000 iterations

    os.makedirs(model_dir, exist_ok=True)

    for epoch in range(0, epoch_num):
        net.train()

        for i, data in enumerate(salobj_dataloader):
            ite_num = ite_num + 1
            ite_num4val = ite_num4val + 1

            inputs, labels = data['image'], data['label']

            # Plain tensors replace the deprecated Variable wrapper.
            inputs_v = inputs.type(torch.FloatTensor).to(device)
            labels_v = labels.type(torch.FloatTensor).to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            d0, d1, d2, d3, d4, d5, d6 = net(inputs_v)
            loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)

            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            running_tar_loss += loss2.item()

            # delete temporary outputs and loss
            del d0, d1, d2, d3, d4, d5, d6, loss2, loss

            print("[epoch: %3d/%3d, batch: %5d/%5d, ite: %d] train loss: %3f, tar: %3f " % (
                epoch + 1, epoch_num, (i + 1) * batch_size_train, train_num, ite_num,
                running_loss / ite_num4val, running_tar_loss / ite_num4val))

            if ite_num % save_frq == 0:
                # Overwrites the checkpoint at the path the test script
                # loads: ./saved_models/<model_name>/<model_name>.pth
                torch.save(net.state_dict(), os.path.join(model_dir, model_name + '.pth'))
                running_loss = 0.0
                running_tar_loss = 0.0
                net.train()  # resume train
                ite_num4val = 0


if __name__ == "__main__":
    main()



# 四、test.py

import os
from skimage import io, transform
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from u2net import U2NET # full size version 173.6 MB
from u2net import U2NETP # small version u2net 4.7 MB

# normalize the predicted SOD probability map
def normPRED(d):
    """Min-max normalise the predicted probability map ``d`` to [0, 1].

    A constant map (max == min) is returned as all zeros instead of NaN.
    """
    ma = torch.max(d)
    mi = torch.min(d)

    span = ma - mi
    if span == 0:
        return torch.zeros_like(d)

    return (d - mi) / span

def save_output(image_name, pred, d_dir):
    """Save the predicted map as ``<d_dir>/<image base name>.png`` at the source image's size.

    Args:
        image_name: Path of the source image; it is read again only to get
            the size the prediction is resized to.
        pred: Predicted map tensor; singleton dimensions are squeezed away.
        d_dir: Output directory.
    """
    predict_np = pred.squeeze().cpu().detach().numpy()

    im = Image.fromarray(predict_np * 255).convert('RGB')

    # The prediction is produced at network resolution; resize it back to
    # the resolution of the image it was computed from.
    image = io.imread(image_name)
    imo = im.resize((image.shape[1], image.shape[0]), resample=Image.BILINEAR)

    # os.path handles both "/" and "\\" separators and keeps every dot but
    # the extension's in the base name.
    imidx = os.path.splitext(os.path.basename(image_name))[0]

    imo.save(os.path.join(d_dir, imidx + '.png'))

def main():
    """Run the trained model on every file in ``./test_data/test_images/`` and save the maps.

    NOTE(review): SalObjDataset, RescaleT and ToTensorLab are the data-loader
    classes; import them here if they live in a separate module.
    """
    from torch.utils.data import DataLoader

    # --------- 1. get image path and name ---------
    model_name = 'u2net'  # u2netp

    image_dir = './test_data/test_images/'
    prediction_dir = './test_data/' + model_name + '_results/'
    model_dir = './saved_models/' + model_name + '/' + model_name + '.pth'

    img_name_list = glob.glob(image_dir + '*')
    print(img_name_list)

    # --------- 2. dataloader ---------
    test_salobj_dataset = SalObjDataset(img_name_list=img_name_list,
                                        lbl_name_list=[],
                                        transform=transforms.Compose([RescaleT(320),
                                                                      ToTensorLab(flag=0)])
                                        )
    # shuffle=False keeps batch index i_test aligned with img_name_list.
    test_salobj_dataloader = DataLoader(test_salobj_dataset,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=1)

    # --------- 3. model define ---------
    if model_name == 'u2net':
        net = U2NET(3, 1)
    elif model_name == 'u2netp':
        net = U2NETP(3, 1)
    else:
        raise ValueError("unknown model_name: " + model_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    net.load_state_dict(torch.load(model_dir, map_location=device))
    net.to(device)
    net.eval()

    os.makedirs(prediction_dir, exist_ok=True)

    # --------- 4. inference for each image ---------
    with torch.no_grad():
        for i_test, data_test in enumerate(test_salobj_dataloader):

            print("inferencing:", os.path.basename(img_name_list[i_test]))

            inputs_test = data_test['image'].type(torch.FloatTensor).to(device)

            d1, d2, d3, d4, d5, d6, d7 = net(inputs_test)

            # normalization: d1 is the first (fused) output of the network
            pred = normPRED(d1[:, 0, :, :])

            # save results to test_results folder
            save_output(img_name_list[i_test], pred, prediction_dir)

            del d1, d2, d3, d4, d5, d6, d7


if __name__ == "__main__":
    main()



# 五、crop.py

# -*- coding: utf-8 -*-

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

def crop(img_file, mask_file):
    """Return ``img_file`` as an RGBA image whose alpha channel is the mask.

    Args:
        img_file: Path of the source image.
        mask_file: Path of the predicted mask for that image; its first
            channel becomes the alpha channel.

    Raises:
        ValueError: If the image and the mask differ in height or width.
    """
    # convert("RGB") guarantees three channels for both inputs, so the
    # concatenation below always yields exactly four (R, G, B, A).
    img_array = np.array(Image.open(img_file).convert("RGB"))
    mask = np.array(Image.open(mask_file).convert("RGB"))

    if mask.shape[:2] != img_array.shape[:2]:
        raise ValueError("image and mask sizes differ: {} vs {}".format(
            img_array.shape[:2], mask.shape[:2]))

    res = np.concatenate((img_array, mask[:, :, [0]]), -1)
    img = Image.fromarray(res.astype('uint8'), mode='RGBA')
    return img


if __name__ == "__main__":
    import os

    model = "u2net"
    # model = "u2netp"

    img_root = "test_data/test_images"
    # Directory the test script writes its predictions to.
    mask_root = "test_data/{}_results".format(model)
    crop_root = "test_data/{}_crops".format(model)
    os.makedirs(crop_root, mode=0o775, exist_ok=True)

    for img_file in os.listdir(img_root):
        print("crop image {}".format(img_file))
        # Predictions are saved as "<name without extension>.png".
        name = os.path.splitext(img_file)[0]
        res = crop(
            os.path.join(img_root, img_file),
            os.path.join(mask_root, name + ".png"),
        )
        res.save(os.path.join(crop_root, name + "_crop.png"))
# exit()



# 六、效果展示

• 点赞 1
• 评论 5
• 分享
x

海报分享

扫一扫，分享海报

• 收藏 5
• 手机看

分享到微信朋友圈

x

扫一扫，手机阅读

• 打赏

打赏

wa1tzy

你的鼓励将是我创作的最大动力

C币 余额
2C币 4C币 6C币 10C币 20C币 50C币
• 一键三连

点赞Mark关注该博主, 随时了解TA的最新博文

06-18 93
10-17