语义分割和实例分割概述
分割任务就是在原始图像中逐像素的找到。
语义分割就是把每个像素都打上标签(这个像素是人、树、背景等)语义分割只区分类别,不区分类别中具体单位。
实例分割不光要其区分类别,还要区分类别中每一个个体。
逐像素进行判断(类似于二分类属于前景还是背景,或者多分类)
分割任务中的目标函数定义
损失函数
逐像素的交叉熵,交叉熵损失函数公式如下。pos_weight是权重项,因为正负样本比例不均衡一次要考虑此问题,增加权重参数来平衡。
Focal loss
样本也有难易成分,样本谁少谁重要,难易点可以理解为轮廓上的像素点难以区分。
下公式的Gamma值通常设置为2,例如预测正样本概率0.95,(1-0.95)^2=0.0025,简单样本权值小表示对损失函数影响不大。如果预测正样本概率为0.5,(1-0.5)^2=0.25,表示难样本的权值大,要对损失函数影响大。可以区分样本的难易程度,这里的样本区分就是通过权值来表现的。
再结合样本数量的权值得到最终的损失函数,这里的a就是样本数量的权值。
MIOU评估标准
IOU计算
多分类任务时,如下图计算dog的IOU值。iou_dog=801/true_dog+predict_dog-801。分母就是黄色加绿色全部减去重复值。
IoU交并比,MIOU就是计算所有类别的平均值,一般是分割任务评估指标。
U-net语义分割网络
整体结构
编码将输入数据做成一个特征,输出大小一定要和输入大小一致,因为要逐像素分析。编码和解码进行融合。引入了特征拼接操作,之前是加法,现在是浅层和深层进行拼接,所有信息都拼接。
U-net++
整体网络结构,特征更融合,拼接更全面。
更容易剪枝,因为前面也单独又监督训练,可以根据速度要求来快速完成剪枝。
U-net+++
不同的max pool整合低阶特征。
unet医学分割实战源码分析
数据增强模块
使用albumentations模块,可以在github中找到对应模块,数据增强中对应的标签也会更改
代码调试
调试代码时我们首先要找到入口函数也就是train.py,找到对应的主函数main(),首先进行参数初始化设置,通过config传递函数,首先读取数据,随后进行如下的数据增强。再将数据分别通过trainl_loader和val_loader实现数据集batch的导入。
网络计算流程
首先数据处理流程,在给出代码的dataset.py文件中找Dataset类中的__getitem__函数。首先读取数据以及标签,进行数据增强。
开始进行epoch的训练,随着轮次进行训练(main函数)
进入train函数
model.trian()开始训练,通过配置文件的网络结构进行训练,此时看的是archs.py文件。如下图包含了对应的网络模块。其中NestedUNet表示Unet++,在每个网络架构模块中的forward函数打上断点,进入debug阶段。
首先进入NestedUNet函数,为了方便理解代码实现效果,在每一层中输出对应特征图的维度,首先输入图像为(8,3,96,96),继续执行self.conv0_0(),查看对应的构造函数,如下图所示表示其函数为VGGBlock,找到对应的函数同样在forward函数中打上断点,继续执行。
进入VGGBlock函数,可以看到第一个VGGBlock的输入维度是3通道表示最初输入图像的RGB,nb_filter[0]为32表示输出32维度特征图。如下图显示VGGBlock的工作原理,可以发现是进行了两次卷积操作。
为了方便解释unet++的forward流程将代码粘贴如下。代码原理和原理图的原理一致,首先分别经过两个VGGBlock得到96维度以及46维度的特征图如下图圈所示,之后将48维度特征图通过内置的上采样函数再和96维度进行拼接,通过conv0_1得到32维度特征。经过如下图的特征融合最后得到输出值,输出类似于连接全连接层得到(8,1,96,96)此时的1表示类别。还可以通过self.deep_supervision参数得到每个层的预测值。
class NestedUNet(nn.Module):
def __init__(self, num_classes, input_channels=3, deep_supervision=False, **kwargs):
super().__init__()
nb_filter = [32, 64, 128, 256, 512]
self.deep_supervision = deep_supervision
self.pool = nn.MaxPool2d(2, 2)
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])
if self.deep_supervision:
self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
else:
self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
def forward(self, input):
print('input:',input.shape)
x0_0 = self.conv0_0(input)
print('x0_0:',x0_0.shape)
x1_0 = self.conv1_0(self.pool(x0_0)) # 先下采样再进行VGGBLocK
print('x1_0:',x1_0.shape)
x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], 1)) #拼接
print('x0_1:',x0_1.shape)
x2_0 = self.conv2_0(self.pool(x1_0))
print('x2_0:',x2_0.shape)
x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], 1))
print('x1_1:',x1_1.shape)
x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], 1))
print('x0_2:',x0_2.shape)
x3_0 = self.conv3_0(self.pool(x2_0))
print('x3_0:',x3_0.shape)
x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], 1))
print('x2_1:',x2_1.shape)
x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], 1))
print('x1_2:',x1_2.shape)
x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], 1))
print('x0_3:',x0_3.shape)
x4_0 = self.conv4_0(self.pool(x3_0))
print('x4_0:',x4_0.shape)
x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], 1))
print('x3_1:',x3_1.shape)
x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], 1))
print('x2_2:',x2_2.shape)
x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], 1))
print('x1_3:',x1_3.shape)
x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], 1))
print('x0_4:',x0_4.shape)
if self.deep_supervision:
output1 = self.final1(x0_1)
output2 = self.final2(x0_2)
output3 = self.final3(x0_3)
output4 = self.final4(x0_4)
return [output1, output2, output3, output4]
else:
output = self.final(x0_4)
return output
继续运行到trian.py,得到输出值,进行损失函数计算利用交叉熵计算,以及IOU可计算,梯度清0,反向传播,梯度优化,记录损失和IOU。
训练完模型后进行模型效果验证,训练得到的模型文件在models/dsb2018_96_NestedUNet_woDS/model.pth。通过val.py进行模型验证,如下所示。
import argparse
import os
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
from albumentations.augmentations import transforms
from albumentations.core.composition import Compose
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import archs
from dataset import Dataset
from metrics import iou_score
from utils import AverageMeter
"""
需要指定参数:--name dsb2018_96_NestedUNet_woDS
"""
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--name', default=None,
help='model name')
args = parser.parse_args()
return args
def main():
args = parse_args()
with open('models/%s/config.yml' % args.name, 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
print('-'*20)
for key in config.keys():
print('%s: %s' % (key, str(config[key])))
print('-'*20)
cudnn.benchmark = True
# create model
print("=> creating model %s" % config['arch'])
model = archs.__dict__[config['arch']](config['num_classes'],
config['input_channels'],
config['deep_supervision'])
model = model.cuda()
# Data loading code
img_ids = glob(os.path.join('inputs', config['dataset'], 'images', '*' + config['img_ext']))
img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
_, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=41)
model.load_state_dict(torch.load('models/%s/model.pth' %
config['name']))
model.eval()
val_transform = Compose([
transforms.Resize(config['input_h'], config['input_w']),
transforms.Normalize(),
])
val_dataset = Dataset(
img_ids=val_img_ids,
img_dir=os.path.join('inputs', config['dataset'], 'images'),
mask_dir=os.path.join('inputs', config['dataset'], 'masks'),
img_ext=config['img_ext'],
mask_ext=config['mask_ext'],
num_classes=config['num_classes'],
transform=val_transform)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=config['batch_size'],
shuffle=False,
num_workers=config['num_workers'],
drop_last=False)
avg_meter = AverageMeter()
for c in range(config['num_classes']):
os.makedirs(os.path.join('outputs', config['name'], str(c)), exist_ok=True)
with torch.no_grad():
for input, target, meta in tqdm(val_loader, total=len(val_loader)):
input = input.cuda()
target = target.cuda()
# compute output
if config['deep_supervision']:
output = model(input)[-1]
else:
output = model(input)
iou = iou_score(output, target)
avg_meter.update(iou, input.size(0))
output = torch.sigmoid(output).cpu().numpy()
for i in range(len(output)):
for c in range(config['num_classes']):
cv2.imwrite(os.path.join('outputs', config['name'], str(c), meta['img_id'][i] + '.jpg'),
(output[i, c] * 255).astype('uint8'))
print('IoU: %.4f' % avg_meter.avg)
plot_examples(input, target, model,num_examples=3)
torch.cuda.empty_cache()
def plot_examples(datax, datay, model,num_examples=6):
fig, ax = plt.subplots(nrows=num_examples, ncols=3, figsize=(18,4*num_examples))
m = datax.shape[0]
for row_num in range(num_examples):
image_indx = np.random.randint(m)
image_arr = model(datax[image_indx:image_indx+1]).squeeze(0).detach().cpu().numpy()
ax[row_num][0].imshow(np.transpose(datax[image_indx].cpu().numpy(), (1,2,0))[:,:,0])
ax[row_num][0].set_title("Orignal Image")
ax[row_num][1].imshow(np.squeeze((image_arr > 0.40)[0,:,:].astype(int)))
ax[row_num][1].set_title("Segmented Image localization")
ax[row_num][2].imshow(np.transpose(datay[image_indx].cpu().numpy(), (1,2,0))[:,:,0])
ax[row_num][2].set_title("Target image")
plt.show()
if __name__ == '__main__':
main()
通过验证得到对应的效果图如下图所示。分别展示了原始图像、分割图像、标签图像。