医学图像处理——实战篇(二)

引言

上一篇中,和大家一起做了一个细胞分类的小模型,相信大家都已经可以自己搭一个简单的图像分类框架。

在这一篇中,和大家一起做一个图像分割的模型。共勉!

医学专业知识

这次还是从一个大项目中抠出来的一个小功能模块——骨髓腔分割。

同样地,这部分知识是同事辛苦整理的,望珍惜!

骨骼的基本结构:

        骨组织、骨髓、骨膜、椎间关节、周围结缔组织或肌肉

我们的目的就是要通过图像分割得到骨髓腔的区域。

数据集

数据集是公司内部的数据,不能公开。

大家如果没有现成的数据的话,可以用画图工具标几张。

图像和 mask 要一一对应:两者文件名相同,代码会按同一个文件名分别从图像目录和标签目录读取。

模型

可以使用经典的 U-Net、U²-Net、CE-Net 等模型做实验;

后面可以复现较新的分割论文,以达到更好的效果,学霸也可以自己创新。

本篇就使用经典的u2net。

模型测试结果

最终结果整体还算不错,IoU 也很高,

但是细看的话,只能说边缘区域马马虎虎。

代码

mydataset.py(注意:train.py 中写的是 from dataset_imgaug import Dataset_imgaug,因此保存时请把该文件命名为 dataset_imgaug.py)


import cv2
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import imgaug as ia
import imgaug.augmenters as iaa
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import sys

# Seed imgaug's global random state so the augmentation draws below are
# reproducible from run to run.
ia.seed(16)

class ToTensor(object):
    """Convert an image (and optional mask) ndarray into CHW torch tensors.

    The image is scaled to [0, 1] by its own maximum value and then
    normalised per channel with ImageNet-style mean/std constants; the mask
    is divided by 255.
    """

    def __call__(self, image, label=None):
        """Convert ``image`` / ``label`` to tensors.

        Args:
            image: ndarray of shape (H, W), (H, W, 1) or (H, W, C>=3).
            label: optional ndarray of shape (H, W) or (H, W, 1), values
                expected in 0..255.

        Returns:
            ``(image_tensor, label_tensor)`` where ``image_tensor`` is a
            float64 tensor of shape (3, H, W) and ``label_tensor`` is a
            float64 tensor of shape (1, H, W), or ``None`` when no label is
            given.
        """
        if image.ndim == 2:
            # Accept plain 2-D grayscale input by giving it a channel axis.
            image = image[:, :, np.newaxis]
        tmpImg = np.zeros((image.shape[0], image.shape[1], 3))

        max_val = np.max(image)
        # An all-zero image would otherwise be 0/0 and turn into NaNs.
        image = image / max_val if max_val != 0 else image.astype(np.float64)

        if image.shape[2] == 1:
            # Single channel: replicate the normalised plane into 3 channels.
            gray = (image[:, :, 0] - 0.485) / 0.229
            tmpImg[:, :, 0] = gray
            tmpImg[:, :, 1] = gray
            tmpImg[:, :, 2] = gray
        else:  # bgr
            # NOTE(review): for cv2 (BGR) input channel 0 is blue but receives
            # the 0.485/0.229 statistics listed first in the ImageNet (RGB)
            # convention. Kept as-is because existing weights were trained
            # with this normalisation.
            tmpImg[:, :, 2] = (image[:, :, 2] - 0.406) / 0.225
            tmpImg[:, :, 1] = (image[:, :, 1] - 0.456) / 0.224
            tmpImg[:, :, 0] = (image[:, :, 0] - 0.485) / 0.229
        tmpImg = tmpImg.transpose((2, 0, 1))

        if label is None:
            return torch.from_numpy(tmpImg), None

        if label.ndim == 3:
            label = label[:, :, 0]
        # Width is shape[1]; using shape[0] twice broke every non-square mask.
        tmpLbl = np.zeros((label.shape[0], label.shape[1], 1))
        tmpLbl[:, :, 0] = label / 255
        tmpLbl = tmpLbl.transpose((2, 0, 1))
        return torch.from_numpy(tmpImg), torch.from_numpy(tmpLbl)


class Dataset_imgaug(Dataset):
    """Segmentation dataset of image/mask pairs that share the same file name.

    Args:
        file_name: sequence of file names; each is joined with ``path_img``
            (and ``path_label``) to locate the image (and its mask).
        path_img: directory holding the images.
        path_label: directory holding the masks, or ``None`` for unlabeled
            data (``__getitem__`` then returns ``None`` as the label).
        is_use_imgaug: when True, apply the imgaug pipeline built below and
            then ``ToTensor``; ``transform`` is ignored in that case.
        img_size: target size passed to ``iaa.Resize`` (imgaug path only).
        transform: callable ``transform(img, label) -> (img, label)`` used
            when ``is_use_imgaug`` is False; defaults to ``ToTensor()``.
    """

    def __init__(self, file_name, path_img, path_label, is_use_imgaug=True, img_size=320, transform=None):
        self.is_use_imgaug = is_use_imgaug
        self.file_name = file_name
        self.path_img = path_img
        self.path_label = path_label
        self.img_size = img_size
        self.ToTensor = ToTensor()
        if is_use_imgaug:
            self.transform = iaa.Sequential([
                iaa.Fliplr(p=0.5),  # horizontal flip
                iaa.Flipud(p=0.5),  # vertical flip
                iaa.SomeOf((0, 5),  # apply between 0 and 5 of these augmenters
                           [iaa.Multiply(),
                            iaa.Sharpen(),  # sharpen
                            iaa.contrast.GammaContrast(),  # gamma contrast
                            iaa.imgcorruptlike.Brightness(),  # brightness
                            iaa.ElasticTransformation(),
                            iaa.Affine(  # affine transform
                                scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},  # scale 80%-120%
                                translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},  # shift within +/-20%
                                rotate=(-45, 45),  # rotate within +/-45 degrees
                                shear=(-16, 16),  # shear within +/-16 degrees
                                order=[0, 1],  # nearest-neighbour or bilinear interpolation
                                cval=0,  # fill new pixels with black
                                mode="constant"),  # how to fill areas outside the image
                            iaa.CropAndPad(
                                px=(-80, 80),
                                pad_cval=0,
                                pad_mode="constant",
                                keep_size=True,
                                sample_independently=False),
                            ]),
                iaa.Resize(img_size),
            ])
        else:
            if transform is not None:
                self.transform = transform
            else:
                self.transform = self.ToTensor

    def __len__(self):
        return len(self.file_name)

    def __getitem__(self, item):
        """Return ``(image_tensor, label_tensor_or_None)`` for index ``item``.

        Raises:
            FileNotFoundError: if the image, or the mask when ``path_label``
                is set, cannot be read by OpenCV.
        """
        _name = self.file_name[item]

        img_path = os.path.join(self.path_img, _name)
        _img = cv2.imread(img_path)
        # cv2.imread returns None instead of raising on a missing/bad file.
        if _img is None:
            raise FileNotFoundError('cannot read image: %s' % img_path)

        _label = None
        if self.path_label is not None:
            label_path = os.path.join(self.path_label, _name)
            _label = cv2.imread(label_path, 0)  # 0 -> read as single-channel grayscale
            if _label is None:
                raise FileNotFoundError('cannot read label: %s' % label_path)

        if self.is_use_imgaug:
            if _label is not None:
                # Wrap the mask so geometric augmentations are applied to it too.
                segmap = SegmentationMapsOnImage(_label, shape=_img.shape)
                _img, segmap = self.transform(image=_img, segmentation_maps=segmap)
                _label = segmap.get_arr()
            else:
                # No mask: building a SegmentationMapsOnImage from None would crash.
                _img = self.transform(image=_img)
            _img, _label = self.ToTensor(_img, _label)
        else:
            # TODO: to be improved -- different transforms expect different inputs.
            _img, _label = self.transform(_img, _label)

        if _label is not None:
            return _img.float(), _label.float()
        else:
            return _img.float(), None

train.py


import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, recall_score
from dataset_imgaug import Dataset_imgaug
import argparse
import yaml
from u2net import U2NET
import matplotlib.pyplot as plt
import numpy as np
import time
from tqdm import tqdm
from tool.tool_gzz import *
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _binarize(prob, threshold=0.5):
    """Return a detached 0/1 copy of probability map ``prob``.

    Values ``>= threshold`` become 1 and values ``< threshold`` become 0, as
    the previous in-place ``ge``/``lt`` masking did, but without overwriting
    the network output or recording the op in the autograd graph.
    """
    return (prob.detach() >= threshold).to(prob.dtype)


def _batch_iou(prob, label):
    """Return the IoU of the last class for one batch as a Python float.

    ``ConfusionMatrix`` and ``IntersectionOverUnion`` come from the project
    star imports. With 2 classes, ``[-1]`` selects class 1 (presumably the
    foreground / marrow cavity -- TODO confirm against ``utils``).
    """
    pred = _binarize(prob)
    confusion = ConfusionMatrix(2, pred.squeeze(), label.squeeze().data)
    return float(IntersectionOverUnion(confusion)[-1])


def _save_curves(train_loss, test_loss, train_iou, test_iou, path='./debug/result.png'):
    """Save loss (left) and IoU (right) curves; train in blue, test in red."""
    x = np.arange(len(train_loss))
    fig = plt.figure()
    plt.subplot(121)
    plt.title('loss')
    plt.plot(x, train_loss, 'b')
    plt.plot(x, test_loss, 'r')
    plt.subplot(122)
    plt.title('iou')
    plt.plot(x, train_iou, 'b')
    plt.plot(x, test_iou, 'r')
    plt.savefig(path)
    # Close explicitly; otherwise every call leaks an open figure.
    plt.close(fig)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='config.yaml', type=str)
    args = parser.parse_args()

    # Read hyper-parameters from the YAML config.
    print('load parser')
    with open(args.config, errors='ignore') as f:
        config = yaml.safe_load(f)
    path_imgs = config.get('path_imgs')
    path_labels = config.get('path_labels')
    batch_size = config.get('batch_size')
    epochs = config.get('epochs')
    lr = config.get('lr')
    in_channel = config.get('in_channel')
    out_channel = config.get('out_channel')
    image_size = config.get('image_size')
    # The config key is named 'best_iou', but its value is used as the initial
    # upper bound on the *test loss* for checkpointing (lower is better).
    best_loss = config.get('best_iou')
    if best_loss is None:
        best_loss = float('inf')
    checkpoint = config.get('checkpoint')
    warmup_step = config.get('warmup_step')
    # Per-epoch LR increment for the (currently disabled) linear warm-up below;
    # guarded so a missing/zero warmup_step does not raise ZeroDivisionError.
    cha = (lr - 0.000001) / warmup_step if warmup_step else 0.0

    # 80/20 train/test split over the image file names.
    imgs_list = os.listdir(path_imgs)
    train_data, test_data = train_test_split(imgs_list, test_size=0.2, random_state=16)

    # NOTE(review): augmentation is disabled for the training split as well;
    # pass is_use_imgaug=True here to enable the imgaug pipeline.
    train_dataset = Dataset_imgaug(train_data, path_imgs, path_labels, is_use_imgaug=False)
    test_dataset = Dataset_imgaug(test_data, path_imgs, path_labels, is_use_imgaug=False)

    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=0)
    test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=0)

    # Model, optimizer and loss.
    print('creat mode')
    net = U2NET(in_channel, out_channel).to(device)
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)
    # Binary cross-entropy; reduction='mean' replaces the deprecated size_average=True.
    criterion = nn.BCELoss(reduction='mean').to(device)

    os.makedirs('./weights_save', exist_ok=True)
    os.makedirs('./debug', exist_ok=True)

    train_loss, test_loss = [], []
    train_iou, test_iou = [], []

    for epoch in range(epochs):
        start_time = time.time()

        # ---- training ----
        net.train()
        run_loss, iou = 0.0, 0.0
        # Manual linear LR warm-up (disabled):
        # if epoch <= warmup_step:
        #     _lr = 0.000001 + epoch * cha
        #     for param_group in optimizer.param_groups:
        #         param_group['lr'] = _lr

        for i, (data, label) in enumerate(train_dataloader):
            data = data.to(device)
            label = label.to(device)
            optimizer.zero_grad()
            d0, d1, d2, d3, d4, d5, d6 = net(data)
            loss0, loss = Muti_BCELOSS(criterion, d0, d1, d2, d3, d4, d5, d6, label)
            loss.backward()
            optimizer.step()
            run_loss += loss.item()

            # NOTE(review): IoU is measured on d1. If U2NET returns its fused
            # map first (d0), d0 may be the intended prediction -- confirm.
            iou += _batch_iou(d1[:, 0, :, :], label)

            if i % 20 == 19:
                print('train: {epoch:%d %d/%d}: loss=%3.3f  iou=%3.3f' % (epoch, i, len(train_dataloader), run_loss / (i + 1), iou / (i + 1)))

        train_loss.append(run_loss / len(train_dataloader))
        train_iou.append(iou / len(train_dataloader))

        # ---- evaluation ----
        net.eval()
        with torch.no_grad():
            run_loss, iou = 0.0, 0.0
            for i, (data, label) in enumerate(test_dataloader):
                data = data.to(device)
                label = label.to(device)
                d0, d1, d2, d3, d4, d5, d6 = net(data)
                loss0, loss = Muti_BCELOSS(criterion, d0, d1, d2, d3, d4, d5, d6, label)
                run_loss += loss.item()

                iou += _batch_iou(d1[:, 0, :, :], label)

                if i % 20 == 19:
                    print('test: {epoch:%d %d/%d}: loss=%3.3f  iou=%3.3f' % (epoch, i, len(test_dataloader), run_loss / (i + 1), iou / (i + 1)))

        test_loss.append(run_loss / len(test_dataloader))
        test_iou.append(iou / len(test_dataloader))

        # Checkpoint whenever the test loss improves.
        if test_loss[-1] < best_loss:
            torch.save(net.state_dict(), r'./weights_save/u2net_sm_gusui_seg_%s.pth' % str(test_loss[-1]))
            best_loss = test_loss[-1]

        end_time = time.time()
        print("one epoch used time:", end_time - start_time)

        if epoch % 20 == 19:
            print('best test loss:', best_loss)
            _save_curves(train_loss, test_loss, train_iou, test_iou)

    print('best test loss:', best_loss)
    _save_curves(train_loss, test_loss, train_iou, test_iou)

pred.py


import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
import torch.nn as nn
from u2net import U2NET
import numpy as np
import time
from tqdm import tqdm
import cv2
from tool.tool_gzz import *
from tool.read_img import OpenSlideImg, array_to_STAI

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Input directory of slides to segment.
# path_imgs = 'demo'
path_imgs = r'D:\gzz\data\BigMouse_Marrow\temp'
# Sorted so slides are processed in a deterministic order.
imgs_list = sorted(os.listdir(path_imgs))
# imgs_list = ['ST21Rf-SN-208-1-000016-1F.svs']

path_model = r'./weight/u2net_cell_seg_tensor(0.9427).pth'
path_save = r'./debug'
os.makedirs(path_save, exist_ok=True)

# Sliding-window tile size and stride (stride < size -> overlapping tiles).
crop_size = 320
crop_step = 160

net = U2NET(3, 1).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
net.load_state_dict(torch.load(path_model, map_location=device))
net.eval()


# 单张png 大图
# path=r'D:\gzz\data\hongli\dataset\dataset09-512\AI201802_R14-xxxx-CD_14-1-17_1M_x45300_y28202_w2048_h1536_3.png'
# path=r'D:\gzz\data\hongli\dataset\20200616\testset_AI201925\AI201925_R16-xxxx-CD_16-2910-22_1M.png'
# with torch.no_grad():
#     img_=cv2.imread(path)
#     h0, w0, _ = img_.shape
#     img=cv2.cvtColor(img_,cv2.COLOR_BGR2RGB)
#     print(img_.shape)
#     img = SegPadding(img, crop_size, crop_step)
#     # cv2.imwrite('./debug/pad.png',img)
#     # print(img.shape)
#     h, w, _ = img.shape
#     patchs_coord = GetPatchsCoordinate(h, w, crop_size, crop_step)
#     img = Normalization(img, is_transpose=True)
#     print('img is ok')
#
#     all_mask = []
#     for i in range(len(patchs_coord)):
#         [[x1, y1], [x2, y2]] = patchs_coord[i]
#         data = torch.from_numpy(img[:, y1:y2, x1:x2]).float().unsqueeze(0).to(device)
#         d0, d1, d2, d3, d4, d5, d6 = net(data)
#
#         all_mask.append(Pred2Label(d1[:, 0, :, :]))
#     print('pred is ok')
#
#     mask_pad = BuildMask(all_mask, patchs_coord, h, w, crop_size, crop_step)
#     mask = mask_pad[:h0, :w0]
#     print('build mask is ok')
#
#     _, conts_dilate, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
#     cv2.drawContours(img_, conts_dilate, -1, (255, 0, 0), 1)
#     cv2.imwrite('./debug/maks1.png', img_)




# Batch inference over whole-slide (.svs) images.
with torch.no_grad():
    for name in imgs_list:
        path_file = os.path.join(path_imgs, name)
        name = get_name(name) + '.png'
        print(name)
        slide = OpenSlideImg(path_file)
        img_ds4 = slide.get_img_ds(4)  # level downsampled by 4
        # Reverse the channel order (to BGR); copy so OpenCV/NumPy helpers get
        # a contiguous array rather than a negative-stride view.
        img_ds41 = np.ascontiguousarray(img_ds4[:, :, ::-1])

        h0, w0, _ = img_ds41.shape
        # Pad so the sliding window tiles the image exactly.
        img = SegPadding(img_ds41, crop_size, crop_step)
        h, w, _ = img.shape
        patchs_coord = GetPatchsCoordinate(h, w, crop_size, crop_step)
        img = Normalization(img, is_transpose=True)
        print('img is ok')

        all_mask = []
        for (x1, y1), (x2, y2) in patchs_coord:
            data = torch.from_numpy(img[:, y1:y2, x1:x2]).float().unsqueeze(0).to(device)
            d0, d1, d2, d3, d4, d5, d6 = net(data)
            # NOTE(review): uses d1; if U2NET returns its fused map first (d0),
            # d0 may be the intended prediction -- confirm.
            all_mask.append(Pred2Label(d1[:, 0, :, :]))
        print('pred is ok')

        # Stitch tile predictions back together and crop away the padding.
        mask_pad = BuildMask(all_mask, patchs_coord, h, w, crop_size, crop_step)
        mask = mask_pad[:h0, :w0]
        print('build mask is ok')

        # findContours returns (image, contours, hierarchy) in OpenCV 3.x and
        # (contours, hierarchy) in 4.x; [-2] selects the contours in both.
        # The crop above is a strided view, so pass a contiguous copy.
        mask_cnts = cv2.findContours(np.ascontiguousarray(mask), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)[-2]
        mask_2 = img_ds4.copy()
        cv2.drawContours(mask_2, mask_cnts, -1, (0, 255, 0), -1)

        # Blend the filled-contour overlay onto the slide image.
        result_img = cv2.addWeighted(img_ds4, 0.78, mask_2, 0.22, 1)

        cv2.imwrite(os.path.join(path_save, name), np.ascontiguousarray(result_img[:, :, ::-1]))

  • 3
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值