一 数据集分割
from PIL import Image
import os
# Split the ISBI 2012 multi-frame TIFF stacks into one TIFF file per frame.
#
# For the dataset, register at: http://brainiac2.mit.edu/isbi_challenge/
# and download train-volume.tif, train-labels.tif and test-volume.tif into
# the working directory.
# Only the 30 training frames come with ground truth. Every fifth frame
# (indices 0, 5, 10, 15, 20, 25 -> 6 images) is put into a separate
# validation folder; the remaining 24 frames are used for training.


def _split_tif_stack(src_path, prefix, train_dir, val_dir=None, n_frames=30, val_every=5):
    """Save each frame of a multi-frame TIFF stack as its own file.

    Frame ``i`` is written as ``<prefix>-<i>.tif``. When *val_dir* is given,
    frames with ``i % val_every == 0`` go to *val_dir* and all others to
    *train_dir*; when *val_dir* is None every frame goes to *train_dir*.
    At most *n_frames* frames are read; a shorter stack stops early.

    Args:
        src_path: path of the multi-frame TIFF to split.
        prefix: file-name prefix for the per-frame output files.
        train_dir: output directory for training (or all) frames.
        val_dir: optional output directory for validation frames.
        n_frames: maximum number of frames to read.
        val_every: every ``val_every``-th frame goes to *val_dir*.
    """
    # The context manager closes the stack's file handle once all frames are saved.
    with Image.open(src_path) as stack:
        for directory in (train_dir, val_dir):
            if directory is not None:
                os.makedirs(directory, exist_ok=True)
        for i in range(n_frames):
            try:
                stack.seek(i)
            except EOFError:  # the stack has fewer than n_frames frames
                break
            is_val = val_dir is not None and i % val_every == 0
            out_dir = val_dir if is_val else train_dir
            stack.save(os.path.join(out_dir, '%s-%s.tif' % (prefix, i)))


print('*************')
_split_tif_stack('./train-volume.tif', 'train-volume',
                 './ISBI 2012/Train-Volume/', './ISBI 2012/Val-Volume/')
_split_tif_stack('./train-labels.tif', 'train-labels',
                 './ISBI 2012/Train-Labels/', './ISBI 2012/Val-Labels/')
_split_tif_stack('./test-volume.tif', 'test-volume', './ISBI 2012/Test-Volume/')
数据集使用的是 ISBI 2012 细胞(电镜神经结构)分割数据集,共 30 张带标注的训练图像,其中 6 张(第 0、5、10、15、20、25 帧)作为验证集,其余 24 张作为训练集。由于 ISBI 2012 的训练数据较少,U-Net 通过图像扭曲(弹性形变)对数据进行增强(augmentation)。图像扭曲数据增强的代码:code
二 数据集加载
import glob
from torch.utils import data
from PIL import Image
import torchvision
import numpy as np
class ISBIDataset(data.Dataset):
    """ISBI 2012 image/label pairs stored as one TIFF file per frame.

    Args:
        gloob_dir_train: glob pattern for the input images. It must contain
            exactly one ``*`` standing for the frame number, e.g.
            ``"./ISBI 2012/Train-Volume/train-volume-*.tif"``.
        gloob_dir_label: glob pattern for the label images, same convention.
        length: value reported by ``len()``. Indices at or beyond the number
            of matched files raise IndexError in ``__getitem__``.
        is_pad: if False, the label is center-cropped to 324x324 (the output
            size of the unpadded network); if True the label is not cropped.
        eval: if True, no augmentation is applied (the name shadows the
            builtin but is kept for compatibility with existing callers).
        totensor: if True, return tensors instead of PIL images.

    Augmentation is controlled by the attributes ``rand_vflip``,
    ``rand_hflip``, ``rand_rotate`` and ``angle``. They start disabled and
    are expected to be set from outside (presumably by the training loop).
    """

    def __init__(self, gloob_dir_train, gloob_dir_label, length, is_pad, eval, totensor):
        self.gloob_dir_train = gloob_dir_train
        self.gloob_dir_label = gloob_dir_label
        self.length = length
        # Expected 512x512 output image size.
        self.crop = torchvision.transforms.CenterCrop(512)
        # Without padding the network output is 324x324.
        self.crop_nopad = torchvision.transforms.CenterCrop(324)
        self.is_pad = is_pad
        self.eval = eval
        self.totensor = totensor
        self.changetotensor = torchvision.transforms.ToTensor()
        # Augmentation switches; all off by default.
        self.rand_vflip = False
        self.rand_hflip = False
        self.rand_rotate = False
        self.angle = 0
        # Sorted file lists, globbed lazily on the first __getitem__ call.
        self._trainfiles = None
        self._labelfiles = None

    def __len__(self):
        'Denotes the total number of samples'
        return self.length

    @staticmethod
    def _sorted_by_number(pattern):
        """Return the files matching *pattern*, sorted by the integer matched by ``*``.

        Assumes ``*`` is the only wildcard, so every match shares the
        pattern's text before ``*`` and after it. For example,
        "train-volume-10.tif" sorts after "train-volume-2.tif".
        """
        star = pattern.rfind('*')
        suffix_len = len(pattern) - star - 1  # characters after the '*'
        return sorted(glob.glob(pattern),
                      key=lambda name: int(name[star:len(name) - suffix_len]))

    def _file_lists(self):
        """Return (image files, label files), globbing only on first use."""
        if self._trainfiles is None or self._labelfiles is None:
            self._trainfiles = self._sorted_by_number(self.gloob_dir_train)
            self._labelfiles = self._sorted_by_number(self.gloob_dir_label)
        return self._trainfiles, self._labelfiles

    @staticmethod
    def _load(path):
        """Read the image at *path* fully into memory and close its file handle."""
        with Image.open(path) as img:
            return img.copy()

    def __getitem__(self, index):
        """Return the (image, label) pair at *index*, augmented unless ``eval`` is set."""
        trainfiles, labelfiles = self._file_lists()
        trainimg = self._load(trainfiles[index])
        trainlabel = self._load(labelfiles[index])
        if not self.eval:
            # NOTE: rand_vflip mirrors left<->right and rand_hflip mirrors
            # top<->bottom, i.e. the names are swapped relative to PIL's
            # terminology; kept as-is because callers set these attributes.
            if self.rand_vflip:
                trainlabel = trainlabel.transpose(Image.FLIP_LEFT_RIGHT)
                trainimg = trainimg.transpose(Image.FLIP_LEFT_RIGHT)
            if self.rand_hflip:
                trainlabel = trainlabel.transpose(Image.FLIP_TOP_BOTTOM)
                trainimg = trainimg.transpose(Image.FLIP_TOP_BOTTOM)
            if self.rand_rotate:
                # Reflect-pad 107 px on each side: a 512x512 frame becomes 726x726,
                # which is larger than 512 * sqrt(2) ~= 724. After rotation, the
                # central 512x512 crop therefore has no black corners. This pad
                # spec only accepts 2-D (single-channel) arrays.
                trainimg = Image.fromarray(np.pad(np.asarray(trainimg), ((107, 107), (107, 107)), 'reflect'))
                trainlabel = Image.fromarray(np.pad(np.asarray(trainlabel), ((107, 107), (107, 107)), 'reflect'))
                # Image.rotate defaults to NEAREST resampling, so label values are not blended.
                trainlabel = trainlabel.rotate(self.angle)
                trainimg = trainimg.rotate(self.angle)
                # Crop the rotated image back to its true size.
                trainlabel = self.crop(trainlabel)
                trainimg = self.crop(trainimg)
        # When padding is used, do not crop the label image.
        if not self.is_pad:
            trainlabel = self.crop_nopad(trainlabel)
        if self.totensor:
            # ToTensor scales 8-bit images to [0, 1] and .long() truncates the result.
            # For 0/255 masks (presumably the label format) this yields 0/1 class ids.
            trainlabel = self.changetotensor(trainlabel).long()
            trainimg = self.changetotensor(trainimg)
        return trainimg, trainlabel
三 U-net模型
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# 1 MODEL
class Unet(nn.Module):
def __init__(self):
super(Unet, self).__init__()
# All layers which have weights are created and initlialitzed in init.
# parameterless modules are used in functional style F. in forward
# (object version of parameterless modules can be created with nn. init too )
# https://pytorch.org/docs/master/nn.html#conv2d
# in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias
self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0)
# https://pytorch.org/docs/master/nn.html#batchnorm2d
# num_features/channels, eps, momentum, affine, track_running_stats