[pytorch]医学图像之肝脏语义分割(训练+预测代码)

一,Unet结构:

结合上图的Unet结构,pytorch的unet代码如下:

unet.py:

import torch.nn as nn
import torch
from torch import autograd


class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by BatchNorm and ReLU.

    Spatial size is preserved (padding=1); only the channel count changes
    from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        """Apply the conv-bn-relu stack to ``input`` of shape (N, in_ch, H, W)."""
        return self.conv(input)


class Unet(nn.Module):
    """Classic U-Net for (binary) semantic segmentation.

    Encoder: four DoubleConv + 2x2 max-pool stages (64 -> 512 channels) plus
    a 1024-channel bottleneck.  Decoder: four transpose-conv upsampling
    stages, each concatenated channel-wise with the matching encoder feature
    map (skip connection) before a DoubleConv.  A final 1x1 conv maps to
    ``out_ch`` channels and a sigmoid squashes logits to [0, 1].

    Input H and W must be divisible by 16 so the four poolings are exactly
    undone by the four upsamplings (otherwise torch.cat shape-mismatches).
    """

    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()

        # encoder
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        # bottleneck
        self.conv5 = DoubleConv(512, 1024)
        # decoder: each DoubleConv sees up-channels + skip-channels
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        # 1x1 conv to the requested number of output channels
        self.conv10 = nn.Conv2d(64, out_ch, 1)

    def forward(self, x):
        """Return per-pixel probabilities of shape (N, out_ch, H, W)."""
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)  # skip connection from encoder stage 4
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        # FIX: use the functional sigmoid instead of constructing a new
        # nn.Sigmoid module on every forward pass.
        out = torch.sigmoid(c10)
        return out

二,数据集:肝脏图

三,代码

主代码:main.py

import numpy as np
import torch
import argparse
from torch.utils.data import DataLoader
from torch import autograd, optim
from torchvision.transforms import transforms
from unet import Unet
from dataset import LiverDataset
from mIou import *
import os
import cv2
# 是否使用cuda
import PIL.Image as Image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def get_data(i):
    """Return the (image_path, mask_path) tuple of the i-th validation sample."""
    import dataset
    pairs = dataset.make_dataset(r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\data\val")
    image_paths = [pair[0] for pair in pairs]
    mask_paths = [pair[1] for pair in pairs]
    return image_paths[i], mask_paths[i]



def train_model(model, criterion, optimizer, dataload, num_epochs=21):
    """Run the training loop and save the final weights.

    Args:
        model: network to train (expected to already live on ``device``).
        criterion: loss applied to (outputs, labels) batches.
        optimizer: optimizer built over ``model.parameters()``.
        dataload: DataLoader yielding (image, mask) batch pairs.
        num_epochs: number of full passes over the dataset.

    Returns:
        The trained model.  Weights are saved once, after the final epoch,
        to the hard-coded path below (not per-epoch).
    """
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        # total sample count, used to report "step/steps_per_epoch"
        dt_size = len(dataload.dataset)
        epoch_loss = 0
        step = 0
        for x, y in dataload:
            step += 1
            inputs = x.to(device)
            labels = y.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # accumulate the summed (not averaged) loss for this epoch
            epoch_loss += loss.item()
            print("%d/%d,train_loss:%0.3f" % (step, (dt_size - 1) // dataload.batch_size + 1, loss.item()))
        print("epoch %d loss:%0.3f" % (epoch, epoch_loss))
    #torch.save(model.state_dict(), './weights_%d.pth' % epoch)
    # NOTE(review): the checkpoint path is hard-coded to a local Windows dir.
    torch.save(model.state_dict(), r'H:\BaiduNetdisk\BaiduDownload\u_net_liver-master/weights.pth')
    return model


# 训练模型
# Train the model
def train():
    """Build the U-Net and the liver DataLoader, then run the training loop."""
    model = Unet(3, 1).to(device)
    criterion = torch.nn.BCELoss()
    optimizer = optim.Adam(model.parameters())
    train_root = r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\data\train"
    liver_dataset = LiverDataset(train_root, transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
    train_model(model, criterion, optimizer, dataloaders)


# 显示模型的输出结果
# Visualise the model's predictions
def test():
    """Visualise predictions on the validation set and report mean IoU.

    Loads trained weights from ``args.ckp``, runs the U-Net (3-channel input,
    1-channel output: liver is the only foreground class) on every validation
    image, shows each input next to its predicted mask, and averages the
    per-image IoU over the whole validation set.
    """
    model = Unet(3, 1).to(device)
    model.load_state_dict(torch.load(args.ckp, map_location='cpu'))  # load trained weights
    liver_dataset = LiverDataset(r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\data\val", transform=x_transforms, target_transform=y_transforms)
    dataloaders = DataLoader(liver_dataset, batch_size=1)
    model.eval()
    import matplotlib.pyplot as plt
    plt.ion()  # interactive mode: show each prediction as soon as it is made

    with torch.no_grad():
        miou_total = 0
        num = len(dataloaders)  # total number of validation images
        for i, (x, _) in enumerate(dataloaders):
            x = x.to(device)
            y = model(x)

            # drop the batch/channel dims and move to numpy before scoring
            img_y = torch.squeeze(y).cpu().numpy()
            img_path, mask_path = get_data(i)  # paths of the i-th image/mask pair
            miou_total += get_iou(mask_path, img_y)  # accumulate per-image IoU
            plt.subplot(121)
            plt.imshow(Image.open(img_path))
            plt.subplot(122)
            plt.imshow(img_y)
            plt.pause(0.01)
        plt.show()
        # FIX: average over the actual number of images instead of a
        # hard-coded 20 (guard against an empty validation set).
        print('Miou=%f' % (miou_total / max(num, 1)))

if __name__ =="__main__":
    # Input images: to tensor in [0, 1], then normalised channel-wise to [-1, 1].
    x_transforms = transforms.Compose([
        transforms.ToTensor(),  # -> [0,1]
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # ->[-1,1]
    ])

    # Masks only need conversion to tensor.
    y_transforms = transforms.ToTensor()

    # Argument parser for the command-line options.
    parse = argparse.ArgumentParser()
    parse.add_argument("--action", type=str, help="train or test",default="train")
    parse.add_argument("--batch_size", type=int, default=1)
    parse.add_argument("--ckp", type=str, help="the path of model weight file")
    args = parse.parse_args()

    # train
    # train()        # uncomment to train; keep commented when testing

    # test()
    # NOTE(review): this overrides any --ckp given on the command line.
    args.ckp = r"H:\BaiduNetdisk\BaiduDownload\u_net_liver-master\weights.pth"
    test()

获取数据代码:

dataset.py:

import torch.utils.data as data
import PIL.Image as Image
import os


def make_dataset(root):
    """Collect (image_path, mask_path) pairs from ``root``.

    The directory is assumed to contain ``%03d.png`` training images with
    matching ``%03d_mask.png`` masks, so half of the directory entries are
    images and half are masks.
    """
    pair_count = len(os.listdir(root)) // 2
    return [
        (os.path.join(root, "%03d.png" % idx),
         os.path.join(root, "%03d_mask.png" % idx))
        for idx in range(pair_count)
    ]


class LiverDataset(data.Dataset):
    """Dataset of (image, mask) pairs discovered by ``make_dataset``.

    Args:
        root: directory containing ``%03d.png`` / ``%03d_mask.png`` files.
        transform: optional transform applied to the input image.
        target_transform: optional transform applied to the mask.
    """

    def __init__(self, root, transform=None, target_transform=None):
        self.imgs = make_dataset(root)
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Load (and optionally transform) the ``index``-th image/mask pair."""
        x_path, y_path = self.imgs[index]
        img_x = Image.open(x_path)
        img_y = Image.open(y_path)
        # FIX: the original returned possibly-unbound locals, raising
        # UnboundLocalError whenever a transform was None; fall back to the
        # untransformed PIL images instead.
        if self.transform is not None:
            img_x = self.transform(img_x)
        if self.target_transform is not None:
            img_y = self.target_transform(img_y)
        return img_x, img_y

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.imgs)

四,网络结果:

网络结果一般用

1.直观效果 或者

2.量化指标来评定

语义分割中有一个重要的指标就是miou,平均交并比,其中miou的代码如下:

mIou.py:

import cv2
import numpy as np

class IOUMetric:
    """Accumulate a confusion matrix and derive IoU-style metrics.

    Uses the fast-hist trick: each (true, pred) label pair is encoded as the
    single integer ``num_classes * true + pred`` and tallied with
    ``np.bincount``, giving the confusion matrix in one vectorised pass.
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # running confusion matrix: rows = ground truth, cols = prediction
        self.hist = np.zeros((num_classes, num_classes))

    def _fast_hist(self, label_pred, label_true):
        # keep only pixels whose ground-truth label is a valid class index
        valid = (label_true >= 0) & (label_true < self.num_classes)
        combined = self.num_classes * label_true[valid].astype(int) + label_pred[valid]
        counts = np.bincount(combined, minlength=self.num_classes ** 2)
        return counts.reshape(self.num_classes, self.num_classes)

    def add_batch(self, predictions, gts):
        """Fold one batch of prediction / ground-truth label maps into the histogram."""
        for pred_map, true_map in zip(predictions, gts):
            self.hist += self._fast_hist(pred_map.flatten(), true_map.flatten())

    def evaluate(self):
        """Return (overall acc, mean class acc, per-class IoU, mean IoU, freq-weighted acc)."""
        diag = np.diag(self.hist)
        total = self.hist.sum()
        row_sums = self.hist.sum(axis=1)  # per-class ground-truth pixel counts
        col_sums = self.hist.sum(axis=0)  # per-class predicted pixel counts
        acc = diag.sum() / total
        acc_cls = np.nanmean(diag / row_sums)
        iu = diag / (row_sums + col_sums - diag)
        mean_iu = np.nanmean(iu)
        freq = row_sums / total
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        return acc, acc_cls, iu, mean_iu, fwavacc



def get_iou(mask_name, predict):
    """Score a predicted probability map against a mask image on disk.

    Args:
        mask_name: path to the ground-truth mask (read as grayscale).
        predict: 2-D array of per-pixel probabilities in [0, 1].  It is
            binarised IN PLACE (callers display the thresholded map).

    Returns:
        The mean IoU over the two classes (background, liver).
    """
    image_mask = cv2.imread(mask_name, 0)  # grayscale ground-truth mask
    # Binarise the prediction in place: the network outputs values in
    # [0, 1], with values near 1 meaning "liver", so 0.5 is the threshold.
    # (Replaces the original O(H*W) Python loops, whose extra pixel counter
    # was dead code.)
    predict[...] = (predict >= 0.5)
    # Binarise the mask: it is a grayscale image, so values below 125 count
    # as black (background).
    image_mask = (image_mask >= 125).astype(image_mask.dtype)
    binary_pred = predict.astype(np.int16)

    Iou = IOUMetric(2)  # 2 classes: liver + background
    Iou.add_batch(binary_pred, image_mask)
    a, b, c, d, e = Iou.evaluate()
    print('%s:iou=%f' % (mask_name, d))
    return d

 

五,运行结果:

验证集的miou:

六,代码和数据集获取:

代码和数据集都在以下链接了:

the Liver dataset: link:https://pan.baidu.com/s/1FljGCVzu7HPYpwAKvSVN4Q keyword:5l88

扩展代码和数据集,这个连接是集成形式的,集成了多个unet变体和数据集:https://github.com/Andy-zhujunwen/UNET-ZOO 

  • 41
    点赞
  • 281
    收藏
    觉得还不错? 一键收藏
  • 84
    评论
UNet是一种经典的深度学习模型,特别适用于医学图像分割任务。它结合了卷积神经网络和全卷积网络的优点,在医学图像分割任务中取得了良好的效果。 UNet的结构分为下采样和上采样两部分。下采样部分由卷积和池化层组成,用于捕捉图像的全局特征信息。通过逐步减小图像的尺寸,可以提取出更加抽象的特征。上采样部分由反卷积和卷积层组成,用于恢复图像的分辨率,并产生与原始图像相同分辨率的预测结果。通过跳跃连接将下采样和上采样部分的特征图连接在一起,可以保留更多的细节信息。 在医学图像分割中,UNet可以有效地提取图像中各种器官、组织和病变的轮廓和边缘信息,有助于医生进行疾病诊断和治疗。例如,在肿瘤分割中,UNet可以准确地分割出肿瘤的边界,帮助医生判断肿瘤的恶性程度和定位手术切除范围。 与传统的图像分割方法相比,UNet具有以下优势。首先,它可以自动学习图像中的特征表示,无需手工设计特征。其次,UNet结构中的跳跃连接可以保留更多的细节信息,有助于提高分割结果的准确性。此外,UNet还可以快速训练和推理,适用于处理大规模的医学图像数据。 总之,UNet是一种强大的医学图像分割模型,通过结合卷积神经网络和全卷积网络的优点,可以准确地提取医学图像中的关键信息,为医生的疾病诊断和治疗提供帮助。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 84
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值