实践经历,类似于肿瘤的图像都是3D图像类型的,而3D Image Classification 一般多用于医学领域,之前从未接触过类似的,所以此次实践过程值得记录一下
刚开始接触3D分类的时候,由于没有任何经验所以碰到了很多问题,记录方便自己查看并希望能够帮到和我一样什么都不知道的小白,我会从我将数据的处理到训练完成后的整个过程进行记录,如有错误的地方请多多包含
3D的图像识别于2D的识别存在一些细微的差别,当基本方法思想都一样,主要是对于数据的处理部分需要注意
图像预处理
其实这一部分的话,对于大多数任务来说是不需要的,但是由于我数据的特殊性(区域太小),需要将被识别的部分框出来,这样才能进行识别
我的数据是医学影像常见领域的MRI格式的图像,这批图像有原始数据,以及对于病灶的ROI部分,ROI部分是只有病灶的数据,其位置信息完全对应原始数据的病灶部分,因此需要根据ROI数据将原始数据中的肿瘤部分框出,对于预处理这一部分,因为作者也是通过网上各种查找资料才处理好,前后一共花了大概两周的时间才处理好,所以这里我只介绍用的方法,我是用的是SimpleITK这个库,它是专门对医学影像图像处理的库,请自行参考
我有用到的几个函数:
-
将MRI图像转换为矩阵
sitk.GetArrayFromImage()
-
读取图像
sitk.ReadImage(img)
-
写图像
sitk.WriteImage()
-
将矩阵转换为图像
sitk.GetImgageFromArray()
-
重采样
sitk.Resmaple()
由于时间太过久远我只给出了我仍然还记得的函数,需要用到什么请自行Google(使用VPN)
DataSet部分
Dataset部分由于我们的图像是3D的,因此需要对3D图像,这里我推荐一个库:Torchio正对于MRI影响有很多处理,还有图像增强等各种函数,但是我的数据已经做过增强因此不再使用
import SimpleITK as sitk
from torch.utils.data import Dataset
import torch
import numpy as np
import os
import torchio as tio
class Classification_Dataset(Dataset):
def __init__(self, mode='train', period=''):
self.root_path = 'your path of the data'
self.high_list = []
self.low_list = []
with open('./label.txt', 'r') as f:
for line in f.readlines():
all_label = line.strip('\n').split(':')[1]
data = line.strip('\n').split(':')[0]
if all_label == '0':
self.low_list.append(data)
else:
self.high_list.append(data) # 读取数据的label
train_list = ['110', '106', '74', '111', '30', '112', '131', '33', '56', '150', '35', '20', '139', '64', '9', '5', '138', '116', '