数据集准备
数据来源
MICCAI KiTS19(Kidney Tumor Segmentation Challenge):https://kits19.grand-challenge.org/
KiTS2019是MICCAI19的一个竞赛项目,项目的任务是对3D-CT数据进行肾脏和肾脏肿瘤的分割,官方的数据集提供了210个case作为训练集,90个case作为测试集。共有800多人报名参加了这一竞赛,最终提交的结果的team有126支,其中被认定有效的为100个记录入leaderboard。目前这一竞赛状态为开放性质的,有兴趣的可以参与一下。这一挑战将于2021年继续举办,将提供更多的数据集和标注,任务也将变得更加有挑战,有兴趣的同学可以跟进关注一下。
感谢评论里weixin_40621562老哥提供的数据集百度云版:链接: https://pan.baidu.com/s/1AOQDjPz9ye32DH-oDS0WDw 提取码: d7jk
数据预处理
KiTS19提供的数据是3D CT图像,我们要训练的是最简单的2D U-Net,因此要从3D CT体数据中读取2D切片。数据集的提供方在其Github上很贴心的提供了可视化的代码,是用python调用了nibabel
库处理.nii格式的体数据得到2D的.png格式的切片。可视化的结果如下图所示,需要对切片进行筛选。另外需要补充的是在KiTS的数据集中分割的标签有三类:背景、肾脏、肾脏肿瘤,我们想进行的是简单的背景与肾脏二分类问题而不是多分类问题,因此在可视化过程中比较简单粗暴的将肿瘤视为肾脏的一部分。
一共有210个3D CT的体数据,从每个体数据中选取了10个slice,一共得到了2100张2D的.png格式图像。
网络结构及代码
网络结构
U-Net的结构时最简单的encoder-decoder结构,再加上越级连接,详细的网络结构请见我的另一篇博客深度学习图像语义分割网络总结:U-Net与V-Net的Pytorch实现。
训练代码
代码参照了github上的https://github.com/milesial/Pytorch-UNet
import argparse
import logging
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from tqdm import tqdm
from eval import eval_net
from unet import UNet
from visdom import Visdom
from utils.dataset import BasicDataset
from torch.utils.data import DataLoader, random_split
dir_img = 'D:\Dataset\CT-KiTS19\KiTS19\kits19-master\png_datasize\\train_choose\slice_png'
dir_mask = 'D:\Dataset\CT-KiTS19\KiTS19\kits19-master\png_datasize\\train_choose\mask_png'
dir_checkpoint = 'checkpoints/'
def train_net(net,
device,
epochs=5,
batch_size=1,
lr=0.1,
val_percent=0.2,
save_cp=True,
img_scale=1):
dataset = BasicDataset(dir_img, dir_mask, img_scale)
n_val = int(len(dataset) * val_percent)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
#writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
viz=Visdom()
viz.line([0.], [0.], win='train_loss', opts=dict(title='train_loss'))
viz.line([0.], [0.], win='learning_rate', opts=dict(title='learning_rate'))
viz.line([0.], [0.], win='Dice/test'