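Notes from a hands-on project. Images of tumors and similar structures are 3D volumes, and 3D image classification is used mostly in the medical field. I had never worked with anything like this before, so the process is worth writing down.
When I first started with 3D classification I had no experience and ran into many problems. I am recording them for my own reference, and I hope they help other beginners who are starting from zero as I was. I cover the whole process, from data processing through to the end of training. Please bear with any mistakes.
3D image recognition differs from 2D recognition in some small ways, but the basic methods and ideas are the same. The part that needs the most care is data processing.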
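Image preprocessing
For most tasks this step is not needed. My data is unusual, though: the region of interest is very small, so the part to be recognized has to be cropped out before classification can work.
My data consists of MRI images, a common modality in medical imaging. The dataset contains the original scans plus an ROI volume for each lesion. The ROI contains only the lesion, and its position corresponds exactly to the lesion in the original scan, so the tumor region in the original data has to be cropped out according to the ROI. I worked out the preprocessing by searching through many online resources, and it took about two weeks in total, so here I only describe the tool I used. I used SimpleITK, a library built for medical image processing; please refer to its documentation for details.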
The functions I used:
- Convert an MRI image to an array: sitk.GetArrayFromImage()
- Read an image: sitk.ReadImage(img)
- Write an image: sitk.WriteImage()
- Convert an array back to an image: sitk.GetImageFromArray()
- Resample: sitk.Resample()
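This was a long time ago, so I have only listed the functions I still remember. Search Google for anything else you need (a VPN may be required).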
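The cropping step described above can be sketched roughly as follows. This is a minimal sketch, not the author's original preprocessing code: the helper names `bounding_box`, `crop_to_roi` and `crop_file`, and the `margin` parameter, are hypothetical. It assumes the ROI volume lies on the same voxel grid as the original scan, which matches the description that the ROI position corresponds exactly to the lesion.

```python
import numpy as np


def bounding_box(mask, margin=0):
    """Return slices covering the non-zero region of a 3D mask, padded by `margin` voxels."""
    coords = np.argwhere(mask > 0)
    if coords.size == 0:
        raise ValueError("ROI mask is empty")
    lo = np.maximum(coords.min(axis=0) - margin, 0)
    hi = np.minimum(coords.max(axis=0) + 1 + margin, mask.shape)
    return tuple(slice(int(a), int(b)) for a, b in zip(lo, hi))


def crop_to_roi(volume, mask, margin=0):
    """Crop `volume` to the bounding box of the non-zero voxels in `mask`."""
    if volume.shape != mask.shape:
        raise ValueError("volume and ROI mask must have the same shape")
    return volume[bounding_box(mask, margin)]


def crop_file(image_path, roi_path, out_path, margin=0):
    """Hypothetical wrapper: read scan and ROI, crop, and write the result."""
    # Imported here so the NumPy helpers above can be used without SimpleITK.
    import SimpleITK as sitk

    image = sitk.ReadImage(image_path)
    roi = sitk.ReadImage(roi_path)
    volume = sitk.GetArrayFromImage(image)  # array axes are (z, y, x)
    mask = sitk.GetArrayFromImage(roi)
    cropped = sitk.GetImageFromArray(crop_to_roi(volume, mask, margin))
    # Keep the voxel spacing; the origin is not adjusted in this sketch.
    cropped.SetSpacing(image.GetSpacing())
    sitk.WriteImage(cropped, out_path)
```

A small `margin` keeps some tissue around the lesion as context. Whether that helps depends on the data, so treat it as something to experiment with.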
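Dataset
Because our images are 3D, the Dataset has to load and handle 3D volumes. I recommend a library here: TorchIO. It provides many processing operations aimed at MRI images, along with image augmentation and other transforms. My data had already been augmented, so I do not use its augmentation and only use it to resize the volumes.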
import SimpleITK as sitk
from torch.utils.data import Dataset
import torch
import numpy as np
import os
import torchio as tio


class Classification_Dataset(Dataset):
    def __init__(self, mode='train', period=''):
        self.root_path = 'your path of the data'
        self.high_list = []
        self.low_list = []
        # Read the labels: each line of label.txt has the form "<patient id>:<label>".
        with open('./label.txt', 'r') as f:
            for line in f.readlines():
                all_label = line.strip('\n').split(':')[1]
                data = line.strip('\n').split(':')[0]
                if all_label == '0':
                    self.low_list.append(data)
                else:
                    self.high_list.append(data)
        # Note: some IDs (e.g. '53', '122') appear in both lists, so the
        # validation set overlaps the training set.
        train_list = ['110', '106', '74', '111', '30', '112', '131', '33', '56', '150', '35', '20', '139', '64', '9', '5', '138', '116', '142', '130', '28', '32', '39', '141', '12', '26', '4', '63', '120', '140', '44', '122', '124', '73', '87', '149', '85', '40', '16', '90', '127', '34', '137', '67', '1', '29', '153', '133', '126', '78', '23', '53', '144', '70', '104', '89', '15', '51', '69', '58', '7', '143', '8', '105', '10', '128', '2', '62', '13', '151', '109', '152', '108', '81', '79', '36', '145', '57', '68', '60', '115', '3', '37', '42']
        val_list = ['53', '122', '32', '26', '137', '14', '55', '90', '124', '56', '35', '70', '39', '20', '131', '116', '46', '67', '43', '40', '138', '27', '129', '28', '141', '74', '123', '7', '51', '146', '69', '60', '151', '105', '143', '3', '134', '89', '81', '10', '145', '109']
        print(f'Have got {len(val_list)} imgs')
        assert os.path.exists(self.root_path), 'Dataset root: {} does not exist.'.format(self.root_path)
        self.mode = mode
        self.period = period
        self.data_path = []  # directory, e.g. G:\WPJ\data\EOB\DATA\1\CE3
        self.img_path = []   # image file, e.g. G:\WPJ\data\EOB\DATA\1\CE3\CE_path_333.mha
        if mode == 'train':
            for i in train_list:
                self.data_path.append(os.path.join(self.root_path, i, period))
                # self.img_path.append(os.path.join(root_path, i, period, (period + '.mha')))
            # Training uses every file in each patient's period directory.
            for data in self.data_path:
                for i in os.listdir(data):
                    self.img_path.append(os.path.join(data, i))
        elif mode == 'val':
            # Validation uses only <period>.mha for each patient.
            for i in val_list:
                self.data_path.append(os.path.join(self.root_path, i, period))
                self.img_path.append(os.path.join(self.root_path, i, period, (period + '.mha')))

    def __getitem__(self, item):
        img = sitk.ReadImage(self.img_path[item])
        transform = tio.Resize((64, 64, 64))  # resize every volume to 64 x 64 x 64
        img = transform(img)
        img_arr = sitk.GetArrayFromImage(img)
        img_arr = np.expand_dims(img_arr, 0)  # add the channel dimension: (1, 64, 64, 64)
        max_num = torch.tensor(np.max(img_arr), dtype=torch.float32)
        img_data = torch.tensor(img_arr, dtype=torch.float32)
        img_data = img_data * (1.0 / max_num)  # scale intensities by the volume maximum
        # The patient ID is the third path component from the end (Windows-style paths).
        patient = self.img_path[item].split('\\')[-3]
        if patient in self.high_list:
            img_label = 1
        elif patient in self.low_list:
            img_label = 0
        else:
            raise KeyError(f'No label found for patient {patient}')
        return img_data, img_label

    def __len__(self):
        return len(self.img_path)
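The label lookup in the Dataset relies on two conventions: each line of label.txt has the form `<patient id>:<label>`, and the patient ID is the third path component from the end. The sketch below shows one way to handle both with the standard library only. The helper names `parse_labels` and `patient_id` are hypothetical, and the `path_cls` parameter is an assumption added so the same code works with either Windows or POSIX separators.

```python
from pathlib import PurePath, PurePosixPath, PureWindowsPath


def parse_labels(lines):
    """Map patient ID -> label (0 = low, 1 = high) from lines such as '12:1'."""
    labels = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        patient, label = line.split(':', 1)
        # Same rule as the Dataset: '0' is low, anything else is high.
        labels[patient.strip()] = 0 if label.strip() == '0' else 1
    return labels


def patient_id(img_path, path_cls=PurePath):
    """Return the patient folder name from <root>/<patient>/<period>/<file>."""
    return path_cls(img_path).parts[-3]


# Example with the Windows path format quoted in the Dataset comments.
example = r'G:\WPJ\data\EOB\DATA\1\CE3\CE_path_333.mha'
print(patient_id(example, PureWindowsPath))  # 1
```

Storing the labels in a dictionary means a patient without a label raises a `KeyError` at lookup time rather than silently leaving the label undefined.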
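Model
For small datasets I do not recommend the 3D CNNs you find on GitHub, because they are far too deep. I started with a 3D ResNet-10, but the results were not good, so I switched to a 3D version of LeNet-5. During training, however, the model would not fit at all: performance was poor on both the training set and the test set. After a lot of trial and error, adding BN (batch normalization) layers got it to train.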
import torch.nn as nn
import torch.nn.functional as F


class LeNet_BN(nn.Module):
    def __init__(self, num_classes=2):
        super(LeNet_BN, self).__init__()
        self.conv1 = nn.Conv3d(1, 16, 3, padding='same')
        self.bn1 = nn.BatchNorm3d(16)
        self.pool1 = nn.MaxPool3d(2, stride=2)
        self.conv2 = nn.Conv3d(16, 32, 3, padding='same')
        self.bn2 = nn.BatchNorm3d(32)
        self.pool2 = nn.MaxPool3d(2, stride=2)
        self.conv3 = nn.Conv3d(32, 64, 3, padding='same')
        self.bn3 = nn.BatchNorm3d(64)
        self.pool3 = nn.MaxPool3d(2, stride=2)
        self.fc1 = nn.Linear(64*8*8*8, 1024)
        self.fc2 = nn.Linear(1024, 32)
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        # Shapes below omit the batch dimension.
        x = F.relu(self.bn1(self.conv1(x)))  # input(1, 64,64,64) output(16, 64,64,64)
        x = self.pool1(x)                    # output(16, 32,32,32)
        x = F.relu(self.bn2(self.conv2(x)))  # output(32, 32,32,32)
        x = self.pool2(x)                    # output(32, 16,16,16)
        x = F.relu(self.bn3(self.conv3(x)))  # output(64, 16,16,16)
        x = self.pool3(x)                    # output(64, 8,8,8)
        x = x.view(-1, 64*8*8*8)             # output(64*8*8*8)
        x = F.relu(self.fc1(x))              # output(1024)
        x = F.relu(self.fc2(x))              # output(32)
        x = self.fc3(x)                      # output(num_classes)
        # x = F.softmax(x, dim=1)
        return x
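The input size of `fc1` (64*8*8*8) is tied to the 64 x 64 x 64 volumes produced by the Dataset. The following is a small sketch of that arithmetic, assuming each stage is a 'same'-padded convolution followed by a 2x max pool as in the model above. The function name `flattened_features` is hypothetical.

```python
def flattened_features(input_size=64, channels=(16, 32, 64), pool=2):
    """Number of features entering the first fully connected layer."""
    size = input_size
    for _ in channels:
        # A 3x3x3 convolution with padding='same' keeps the spatial size;
        # MaxPool3d(2, stride=2) halves it (rounding down).
        size = size // pool
    return channels[-1] * size ** 3


print(flattened_features())    # 32768, i.e. 64*8*8*8
print(flattened_features(32))  # 4096, i.e. 64*4*4*4
```

If the resize target in the Dataset changes, `fc1` and the `view` call in `forward` need to change with it.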