U-net医学图像分割

该博客介绍了如何使用U-net模型进行医学图像分割,特别是ISBI2012细胞检测数据集的处理。作者首先对数据集进行了划分,然后详细阐述了数据加载、U-net模型构建以及训练过程。在训练部分,提供了使用脚本的详细选项,包括保存每个epoch的图像、使用padding以保持图像尺寸、记录训练损失等信息。
摘要由CSDN通过智能技术生成

代码作者

一 数据集分割

from PIL import Image
import os

# For the Dataset register at : http://brainiac2.mit.edu/isbi_challenge/
# Download the corresponding data files
# Only 30 images are available with ground truth
# 6 Images are used for validation and are put into a seperate folder
img = Image.open('./train-volume.tif')
print('*************')

directory = './ISBI 2012/Train-Volume/'
if not os.path.exists(directory):
    os.makedirs(directory)

directory = './ISBI 2012/Val-Volume/'
if not os.path.exists(directory):
    os.makedirs(directory)
for i in range(30):
    try:
        img.seek(i)
        if i % 5 == 0:
            img.save('./ISBI 2012/Val-Volume/train-volume-%s.tif' % (i,))
        else:
            img.save('./ISBI 2012/Train-Volume/train-volume-%s.tif' % (i,))
    except EOFError:
        break
img = Image.open('./train-labels.tif')
directory = './ISBI 2012/Train-Labels/'
if not os.path.exists(directory):
    os.makedirs(directory)

directory = './ISBI 2012/Val-Labels/'
if not os.path.exists(directory):
    os.makedirs(directory)
for i in range(30):
    try:
        img.seek(i) # i frame
        if i % 5 == 0:
            img.save('./ISBI 2012/Val-Labels/train-labels-%s.tif' % (i,))
        else:
            img.save('./ISBI 2012/Train-Labels/train-labels-%s.tif' % (i,))
    except EOFError:
        break

img = Image.open('./test-volume.tif')
directory = './ISBI 2012/Test-Volume/'
if not os.path.exists(directory):
    os.makedirs(directory)
for i in range(30):
    try:
        img.seek(i)
        img.save('./ISBI 2012/Test-Volume/test-volume-%s.tif' % (i,))
    except EOFError:
        break

数据集使用的是ISBI2012细胞检测数据集,30张训练图像,选其中六张作为验证集,剩下作为训练集。由于ISBI2012训练数据比较少U-net,通过图像扭曲对数据进行augment。图像扭曲增加数据    code

二 数据集加载

import glob
from torch.utils import data
from PIL import Image
import torchvision
import numpy as np

class ISBIDataset(data.Dataset):

    def __init__(self, gloob_dir_train, gloob_dir_label, length, is_pad, eval, totensor):
        self.gloob_dir_train = gloob_dir_train
        self.gloob_dir_label = gloob_dir_label
        self.length = length
        self.crop = torchvision.transforms.CenterCrop(512)#得到期望512*512输出图像
        self.crop_nopad = torchvision.transforms.CenterCrop(324)#没有加padding得到324*324输出
        self.is_pad = is_pad
        self.eval = eval 
        self.totensor = totensor
        self.changetotensor = torchvision.transforms.ToTensor()

        self.rand_vflip = False
        self.rand_hflip = False
        self.rand_rotate = False
        self.angle = 0

    def __len__(self):
        'Denotes the total number of samples'
        return self.length

    def __getitem__(self, index):
        'Generates one sample of data'
        # files are sorted depending the last number in their filename
        # for example : "./ISBI 2012/Train-Volume/train-volume-*.tif"
        trainfiles = sorted(glob.glob(self.gloob_dir_train),
                            key=lambda name: int(name[self.gloob_dir_train.rfind('*'):
                                                      -(len(self.gloob_dir_train) - self.gloob_dir_train.rfind('.'))]))

        labelfiles = sorted(glob.glob(self.gloob_dir_label),
                            key=lambda name: int(name[self.gloob_dir_label.rfind('*'):
                                                      -(len(self.gloob_dir_label) - self.gloob_dir_label.rfind('.'))]))

        trainimg = Image.open(trainfiles[index])
        trainlabel = Image.open(labelfiles[index])


        if not self.eval:
            if self.rand_vflip:
                trainlabel = trainlabel.transpose(Image.FLIP_LEFT_RIGHT)
                trainimg = trainimg.transpose(Image.FLIP_LEFT_RIGHT)

            if self.rand_hflip:
                trainlabel = trainlabel.transpose(Image.FLIP_TOP_BOTTOM)
                trainimg = trainimg.transpose(Image.FLIP_TOP_BOTTOM)

            if self.rand_rotate:
                # Add padding to the image to remove black boarders when rotating
                # image is croped to true size later.
                trainimg = Image.fromarray(np.pad(np.asarray(trainimg), ((107, 107), (107, 107)), 'reflect'))
                trainlabel = Image.fromarray(np.pad(np.asarray(trainlabel), ((107, 107), (107, 107)), 'reflect'))

                trainlabel = trainlabel.rotate(self.angle)
                trainimg = trainimg.rotate(self.angle)
                # crop rotated image to true size
                trainlabel = self.crop(trainlabel)
                trainimg = self.crop(trainimg)


        # when padding is used, dont crop the label image
        if not self.is_pad:
            trainlabel = self.crop_nopad(trainlabel)

        if self.totensor:
            trainlabel = self.changetotensor(trainlabel).long()
            trainimg = self.changetotensor(trainimg)

        return trainimg, trainlabel

三 U-net模型

import torch
import torch.nn as nn
import torch.nn.functional as F
import math


# 1 MODEL
class Unet(nn.Module):

    def __init__(self):
        super(Unet, self).__init__()

        # All layers which have weights are created and initlialitzed in init.
        # parameterless modules are used in functional style F. in forward
        # (object version of parameterless modules can be created with nn. init too )

        # https://pytorch.org/docs/master/nn.html#conv2d
        # in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0)

        # https://pytorch.org/docs/master/nn.html#batchnorm2d
        # num_features/channels, eps, momentum, affine, track_running_stats
       
  • 5
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值