引言
上一篇中,和大家一起做了一个细胞分类的小模型,相信大家都已经可以自己搭一个简单的图像分类框架。
在这一篇中,和大家一起做一个图像分割的模型。共勉!
医学专业知识
这次还是从一个大项目中抠出来的一个小功能模块——骨髓腔分割,
同样的,这部分知识是同事辛苦整理的,望珍惜!
骨骼的基本结构:
骨组织、骨髓、骨膜、椎间关节、周围结缔组织或肌肉
我们的目的就是要通过图像分割得到骨髓腔的区域。
数据集
数据集是公司内部的数据,不能公开。
大家如果没有现成的数据的话,可以用画图工具标几张。
图像和mask要一一对应。
模型
可以使用经典的unet、u2net、cenet做实验,
后面可以复现较新的分割论文,以达到更好的效果,学霸可以自己创新。
本篇就使用经典的u2net。
模型测试结果
最后结果整体还算是可以的,iou也很高,
但是细看的话,只能说边缘区域马马虎虎。
代码
mydataset.py
import cv2
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import sys
ia.seed(16)
class ToTensor(object):
    """Convert ndarray image/label pairs in a sample to CHW float Tensors.

    The image is scaled to [0, 1] by its own maximum value and normalized
    per channel with ImageNet-style mean/std; the 0/255 label mask is scaled
    to {0, 1}.
    """

    def __call__(self, image, label=None):
        # (H, W, 3) buffer for the normalized image, regardless of input depth.
        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))
        # NOTE(review): divides by the per-image max rather than a fixed 255;
        # an all-zero image would divide by zero — confirm inputs are non-blank.
        image = image / np.max(image)
        if image.shape[2] == 1:
            # Grayscale: replicate the single channel into all three slots,
            # each normalized with the same (0.485, 0.229) statistics.
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 1] = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 2] = (image[:, :, 0] - 0.485) / 0.229
        else:  # bgr (cv2 channel order)
            # NOTE(review): for a BGR image, index 2 is the red channel, so the
            # ImageNet R stats (0.485/0.229) would normally go on index 2.
            # Kept exactly as trained, since the saved weights depend on this
            # normalization — confirm before changing.
            tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
            tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
        tmpImg = tmpImg.transpose((2, 0, 1))  # HWC -> CHW
        if label is not None:
            # BUGFIX: was np.zeros((label.shape[0], label.shape[0], 1)), which
            # repeated dim 0 and crashed on any non-square mask; use (H, W, 1).
            tmpLbl = np.zeros((label.shape[0], label.shape[1], 1))
            label = label / 255  # 0/255 mask -> 0/1
            tmpLbl[:, :, 0] = label
            tmpLbl = tmpLbl.transpose((2, 0, 1))  # HWC -> CHW
            return torch.from_numpy(tmpImg), torch.from_numpy(tmpLbl)
        else:
            return torch.from_numpy(tmpImg), None
class Dataset_imgaug(Dataset):
    """Segmentation dataset that optionally runs an imgaug augmentation pipeline.

    Args:
        file_name: list of file names shared by the image and mask directories.
        path_img: directory containing the input images (read with cv2, BGR).
        path_label: directory containing the grayscale masks, or None for
            unlabeled data.
        is_use_imgaug: when True, augment image+mask jointly with imgaug and
            then convert with ToTensor.
        img_size: side length the augmenter resizes samples to.
        transform: optional callable ``transform(image, label)`` used when
            imgaug is disabled; defaults to plain ToTensor.
    """

    def __init__(self, file_name, path_img, path_label, is_use_imgaug=True, img_size=320, transform=None):
        self.is_use_imgaug = is_use_imgaug
        self.file_name = file_name
        self.path_img = path_img
        self.path_label = path_label
        self.img_size = img_size
        self.ToTensor = ToTensor()
        if is_use_imgaug:
            self.transform = iaa.Sequential([
                iaa.Fliplr(p=0.5),  # horizontal mirror flip
                iaa.Flipud(p=0.5),  # vertical mirror flip
                iaa.SomeOf((0, 5),  # apply 0-5 of the following augmenters per image
                           [iaa.Multiply(),
                            iaa.Sharpen(),  # image sharpening
                            iaa.contrast.GammaContrast(),  # gamma contrast
                            iaa.imgcorruptlike.Brightness(),  # brightness jitter
                            iaa.ElasticTransformation(),
                            iaa.Affine(  # affine transform on a subset of images
                                scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},  # scale 80%-120%
                                translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},  # shift +/-20%
                                rotate=(-45, 45),  # rotate +/-45 degrees
                                shear=(-16, 16),  # shear +/-16 deg (rectangle -> parallelogram)
                                order=[0, 1],  # nearest-neighbour or bilinear interpolation
                                cval=0,  # fill revealed pixels with black ("constant")
                                mode="constant"),  # how pixels outside the image are filled
                            iaa.CropAndPad(
                                px=(-80, 80),
                                pad_cval=0,
                                pad_mode="constant",
                                keep_size=True,
                                sample_independently=False),
                            ]),
                iaa.Resize(img_size),
            ])
        else:
            if transform is not None:
                self.transform = transform
            else:
                self.transform = self.ToTensor

    def __len__(self):
        return len(self.file_name)

    def __getitem__(self, item):
        _name = self.file_name[item]
        _img = cv2.imread(os.path.join(self.path_img, _name))
        if self.path_label is not None:
            _label = cv2.imread(os.path.join(self.path_label, _name), 0)  # grayscale mask
        else:
            _label = None
        if self.is_use_imgaug:
            # BUGFIX: an unlabeled sample (path_label is None) previously
            # crashed in SegmentationMapsOnImage; only wrap and augment the
            # mask when one exists.
            if _label is not None:
                _label = SegmentationMapsOnImage(_label, shape=_img.shape)
                _img, _label = self.transform(image=_img, segmentation_maps=_label)
                _label = _label.get_arr()
            else:
                _img = self.transform(image=_img)
            _img, _label = self.ToTensor(_img, _label)
        else:
            _img, _label = self.transform(_img, _label)
            # TODO: different transform callables expect different signatures;
            # unify the interface.
        if _label is not None:
            return _img.float(), _label.float()
        else:
            return _img.float(), None
train.py
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score
from dataset_imgaug import Dataset_imgaug
import argparse
import yaml
from u2net import U2NET
import matplotlib.pyplot as plt
import numpy as np
import time
from tqdm import tqdm
from tool.tool_gzz import *
from utils import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='config.yaml', type=str)
    args = parser.parse_args()

    # Load every hyper-parameter from the YAML config file.
    print('load parser')
    with open(args.config, errors='ignore') as f:
        config = yaml.safe_load(f)
    path_imgs = config.get('path_imgs')
    path_labels = config.get('path_labels')
    batch_size = config.get('batch_size')
    epochs = config.get('epochs')
    lr = config.get('lr')
    in_channel = config.get('in_channel')
    out_channel = config.get('out_channel')
    image_size = config.get('image_size')
    best_iou = config.get('best_iou')  # initial threshold for saving checkpoints
    checkpoint = config.get('checkpoint')
    warmup_step = config.get('warmup_step')
    cha = (lr - 0.000001) / warmup_step  # per-step lr increment for the (disabled) warmup below

    # Split the image list into train/test sets; masks share the same names.
    imgs_list = os.listdir(path_imgs)
    train_data, test_data = train_test_split(imgs_list, test_size=0.2, random_state=16)
    train_dataset = Dataset_imgaug(train_data, path_imgs, path_labels, is_use_imgaug=False)
    test_dataset = Dataset_imgaug(test_data, path_imgs, path_labels, is_use_imgaug=False)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=0)
    test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=0)

    # Build the model, optimizer and loss.
    print('create model')  # BUGFIX: message typo ("creat mode")
    net = U2NET(in_channel, out_channel).to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    # BUGFIX: size_average is deprecated (removed in recent PyTorch);
    # reduction='mean' is the equivalent modern spelling.
    criterion = nn.BCELoss(reduction='mean').to(device)  # binary cross-entropy

    train_loss, test_loss = [], []
    train_iou, test_iou = [], []
    for epoch in range(epochs):
        start_time = time.time()
        # ---------------- training ----------------
        net.train()
        run_loss, iou = 0.0, 0.0
        # Manual learning-rate warmup (currently disabled):
        # if epoch <= warmup_step:
        #     _lr = 0.000001 + epoch * cha
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = _lr
        for i, (data, label) in enumerate(train_dataloader):
            data = data.to(device)
            label = label.to(device)
            optimizer.zero_grad()
            d0, d1, d2, d3, d4, d5, d6 = net(data)
            loss0, loss = Muti_BCELOSS(criterion, d0, d1, d2, d3, d4, d5, d6, label)
            loss.backward()
            optimizer.step()
            run_loss = run_loss + loss.item()
            # NOTE(review): metrics are computed on side output d1; U2NET's
            # fused output is usually d0 — confirm this choice is intentional.
            pred = d1[:, 0, :, :]
            pred[pred.ge(0.5)] = 1  # binarize at 0.5 (in-place; safe, backward already ran)
            pred[pred.lt(0.5)] = 0
            confusionMatrix = ConfusionMatrix(2, pred.squeeze().data, label.squeeze().data)
            IOU = IntersectionOverUnion(confusionMatrix)[-1]
            iou += IOU.cpu().data
            if i % 20 == 19:  # log every 20 batches
                print('train: {epoch:%d %d/%d}: loss=%3.3f iou=%3.3f' % (epoch, i, len(train_dataloader), run_loss / (i + 1), iou / (i + 1)))
        train_loss.append(run_loss / len(train_dataloader))
        train_iou.append(iou / len(train_dataloader))
        # ---------------- evaluation ----------------
        net.eval()
        with torch.no_grad():
            run_loss, iou = 0.0, 0.0
            for i, (data, label) in enumerate(test_dataloader):
                data = data.to(device)
                label = label.to(device)
                d0, d1, d2, d3, d4, d5, d6 = net(data)
                loss0, loss = Muti_BCELOSS(criterion, d0, d1, d2, d3, d4, d5, d6, label)
                run_loss += loss.item()
                pred = d1[:, 0, :, :]
                pred[pred.ge(0.5)] = 1  # binarize at 0.5
                pred[pred.lt(0.5)] = 0
                confusionMatrix = ConfusionMatrix(2, pred.squeeze().data, label.squeeze().data)
                IOU = IntersectionOverUnion(confusionMatrix)[-1]
                iou += IOU.cpu().data
                if i % 20 == 19:
                    print('test: {epoch:%d %d/%d}: loss=%3.3f iou=%3.3f' % (epoch, i, len(test_dataloader), run_loss / (i + 1), iou / (i + 1)))
            test_loss.append(run_loss / len(test_dataloader))
            test_iou.append(iou / len(test_dataloader))
        # Save a checkpoint whenever the test IoU improves.
        # BUGFIX: this previously compared the test *loss* against best_iou
        # (and then stored that loss into best_iou), so checkpointing never
        # actually tracked IoU as the variable name and config key intend.
        if test_iou[-1] > best_iou:
            torch.save(net.state_dict(), r'./weights_save/u2net_sm_gusui_seg_%s.pth' % str(test_iou[-1]))
            best_iou = test_iou[-1]
        end_time = time.time()
        print("one epoch used time:", end_time - start_time)
        # Every 20 epochs, dump intermediate loss/IoU curves for monitoring.
        if epoch % 20 == 19:
            print('best iou:', best_iou)  # BUGFIX: message said "recall" but reports IoU
            x = np.arange(epoch + 1)
            plt.figure()
            p1 = plt.subplot(121)
            plt.title('loss')
            plt.plot(x, train_loss, 'b')
            plt.plot(x, test_loss, 'r')
            p2 = plt.subplot(122)
            plt.title('iou')
            plt.plot(x, train_iou, 'b')
            plt.plot(x, test_iou, 'r')
            plt.savefig('./debug/result.png')
            plt.close()  # BUGFIX: avoid accumulating open figures over a long run
    # Final curves after training completes.
    print('best iou:', best_iou)
    x = np.arange(epochs)
    plt.figure()
    plt.title('loss')
    plt.plot(x, train_loss, 'b')
    plt.plot(x, test_loss, 'r')
    plt.savefig('./debug/result.png')
    plt.close()
pred.py
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
from u2net import U2NET
import numpy as np
import time
from tqdm import tqdm
import cv2
from tool.tool_gzz import *
from tool.read_img import OpenSlideImg, array_to_STAI
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- configuration -----------------------------------------------------------
path_imgs = r'D:\gzz\data\BigMouse_Marrow\temp'
imgs_list = os.listdir(path_imgs)
path_model = r'./weight/u2net_cell_seg_tensor(0.9427).pth'
path_save = r'./debug'
crop_size = 320  # side length of each sliding-window patch fed to the net
crop_step = 160  # stride between patches (50% overlap)

net = U2NET(3, 1).to(device)
# BUGFIX: map_location lets the checkpoint load on CPU-only machines too.
net.load_state_dict(torch.load(path_model, map_location=device))
net.eval()

# Batch inference over whole-slide (.svs) images: tile each downsampled slide,
# predict a mask per tile, stitch the tiles back together, then draw the
# segmentation overlay on the original image.
with torch.no_grad():
    for idx, name in enumerate(imgs_list):
        path_file = os.path.join(path_imgs, name)
        name = get_name(name) + '.png'
        print(name)
        slide = OpenSlideImg(path_file)
        img_ds4 = slide.get_img_ds(4)  # 4x-downsampled slide image
        img_ds41 = img_ds4[:, :, ::-1]  # to BGR (matches cv2-trained input)
        h0, w0, _ = img_ds41.shape
        img = SegPadding(img_ds41, crop_size, crop_step)
        h, w, _ = img.shape
        patchs_coord = GetPatchsCoordinate(h, w, crop_size, crop_step)
        img = Normalization(img, is_transpose=True)
        print('img is ok')

        # BUGFIX: the inner loop index was `i`, shadowing the outer file-loop
        # variable; iterate the coordinates directly instead.
        all_mask = []
        for [[x1, y1], [x2, y2]] in patchs_coord:
            data = torch.from_numpy(img[:, y1:y2, x1:x2]).float().unsqueeze(0).to(device)
            d0, d1, d2, d3, d4, d5, d6 = net(data)
            all_mask.append(Pred2Label(d1[:, 0, :, :]))
        print('pred is ok')

        mask_pad = BuildMask(all_mask, patchs_coord, h, w, crop_size, crop_step)
        mask = mask_pad[:h0, :w0]  # crop padding back to the original extent
        print('build mask is ok')

        # BUGFIX: cv2.findContours returns 3 values on OpenCV 3.x but only 2 on
        # OpenCV >= 4; unpack in a version-independent way.
        cnts_ret = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
        mask_cnts = cnts_ret[0] if len(cnts_ret) == 2 else cnts_ret[1]
        mask_2 = img_ds4.copy()
        cv2.drawContours(mask_2, mask_cnts, -1, (0, 255, 0), -1)  # filled green regions
        result_img = cv2.addWeighted(img_ds4, 0.78, mask_2, 0.22, 1)  # translucent overlay
        cv2.imwrite(os.path.join(path_save, name), result_img[:, :, ::-1])  # back to BGR for cv2