Pytorch-3D-Image-Classification

实践经历,类似于肿瘤的图像都是3D图像类型的,而3D Image Classification 一般多用于医学领域,之前从未接触过类似的,所以此次实践过程值得记录一下

刚开始接触3D分类的时候,由于没有任何经验所以碰到了很多问题,记录方便自己查看并希望能够帮到和我一样什么都不知道的小白,我会从我将数据的处理到训练完成后的整个过程进行记录,如有错误的地方请多多包含

3D的图像识别于2D的识别存在一些细微的差别,当基本方法思想都一样,主要是对于数据的处理部分需要注意

图像预处理

其实这一部分的话,对于大多数任务来说是不需要的,但是由于我数据的特殊性(区域太小),需要将被识别的部分框出来,这样才能进行识别
我的数据是医学影像常见领域的MRI格式的图像,这批图像有原始数据,以及对于病灶的ROI部分,ROI部分是只有病灶的数据,其位置信息完全对应原始数据的病灶部分,因此需要根据ROI数据将原始数据中的肿瘤部分框出,对于预处理这一部分,因为作者也是通过网上各种查找资料才处理好,前后一共花了大概两周的时间才处理好,所以这里我只介绍用的方法,我是用的是SimpleITK这个库,它是专门对医学影像图像处理的库,请自行参考
我有用到的几个函数:

  1. 将MRI图像转换为矩阵sitk.GetArrayFromImage()

  2. 读取图像sitk.ReadImage(img)

  3. 写图像 sitk.WriteImage()

  4. 将矩阵转换为图像 sitk.GetImgageFromArray()

  5. 重采样 sitk.Resmaple()

由于时间太过久远我只给出了我仍然还记得的函数,需要用到什么请自行Google(使用VPN)

DataSet部分

Dataset部分由于我们的图像是3D的,因此需要对3D图像,这里我推荐一个库:Torchio正对于MRI影响有很多处理,还有图像增强等各种函数,但是我的数据已经做过增强因此不再使用

import SimpleITK as sitk
from torch.utils.data import Dataset
import torch
import numpy as np
import os
import torchio as tio

class Classification_Dataset(Dataset):
    def __init__(self, mode='train', period=''):
        self.root_path = 'your path of the data'
        self.high_list = []
        self.low_list = []
        with open('./label.txt', 'r') as f:
            for line in f.readlines():
                all_label = line.strip('\n').split(':')[1]
                data = line.strip('\n').split(':')[0]
                if all_label == '0':
                    self.low_list.append(data)
                else:
                    self.high_list.append(data)       # 读取数据的label
        train_list = ['110', '106', '74', '111', '30', '112', '131', '33', '56', '150', '35', '20', '139', '64', '9', '5', '138', '116', '142', '130', '28', '32', '39', '141', '12', '26', '4', '63', '120', '140', '44', '122', '124', '73', '87', '149', '85', '40', '16', '90', '127', '34', '137', '67', '1', '29', '153', '133', '126', '78', '23', '53', '144', '70', '104', '89', '15', '51', '69', '58', '7', '143', '8', '105', '10', '128', '2', '62', '13', '151', '109', '152', '108', '81', '79', '36', '145', '57', '68', '60', '115', '3', '37', '42']
        val_list = ['53', '122', '32', '26', '137', '14', '55', '90', '124', '56', '35', '70', '39', '20', '131', '116', '46', '67', '43', '40', '138', '27', '129', '28', '141', '74', '123', '7', '51', '146', '69', '60', '151', '105', '143', '3', '134', '89', '81', '10', '145', '109']
        print(f'Have got {len(val_list)} imgs')
        assert os.path.exists(self.root_path), 'Dataset root: {} does not exist.'.format(self.root_path)
        self.mode = mode

        self.period = period
        self.data_path = []  # G:\WPJ\data\EOB\DATA\1\CE3
        self.img_path = []  # 图片路径 G:\WPJ\data\EOB\DATA\1\CE3\CE_path_333.mha
        if mode == 'train':
            for i in train_list:
                self.data_path.append(os.path.join(self.root_path, i, period))
                # self.img_path.append(os.path.join(root_path, i, period, (period + '.mha')))
            for data in self.data_path:
                for i in os.listdir(data):
                    self.img_path.append(os.path.join(data, i))
        elif mode == 'val':
            for i in val_list:
                self.data_path.append(os.path.join(self.root_path, i, period))
                self.img_path.append(os.path.join(self.root_path, i, period, (period+'.mha')))


    def __getitem__(self, item):
        img = sitk.ReadImage(self.img_path[item])
        transform = tio.Resize((64, 64, 64))
        img = transform(img)
        img_arr = sitk.GetArrayFromImage(img)
        img_arr = np.expand_dims(img_arr, 0)
        max_num = torch.tensor(np.max(img_arr), dtype=torch.float32)
        img_data = torch.tensor(img_arr, dtype=torch.float32)
        img_data = img_data * (1.0 / max_num)
        patient = self.img_path[item].split('\\')[-3]
        if patient in self.high_list:
            img_label = 1
        elif patient in self.low_list:
            img_label = 0
        return img_data, img_label
    def __len__(self):
        return len(self.img_path)

模型部分

对于数据量小的图片,不建议大家使用GitHub上的3DCNN,因为实在是太深了,最开始我也使用了3D resnet10,但是效果不是很理想,因此我将模型改为了3D的LeNet5,但是我在训练中遇到了完全不能拟合的情况,指的是训练集和测试集的效果都不行,摸索了很久,加上了BN层之后就可以跑通了

class LeNet_BN(nn.Module):
    def __init__(self,num_classes=2):
        super(LeNet_BN, self).__init__()
        self.conv1 = nn.Conv3d(1, 16, 3,padding='same')
        self.bn1 = nn.BatchNorm3d(16)
        self.pool1 = nn.MaxPool3d(2,stride=2)
        self.conv2 = nn.Conv3d(16, 32, 3,padding='same')
        self.bn2 = nn.BatchNorm3d(32)
        self.pool2 = nn.MaxPool3d(2,stride=2)
        self.conv3 = nn.Conv3d(32, 64, 3,padding='same')
        self.bn3 = nn.BatchNorm3d(64)
        self.pool3 = nn.MaxPool3d(2,stride=2)
        self.fc1 = nn.Linear(64*8*8*8, 1024)
        self.fc2 = nn.Linear(1024, 32)
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        # print(x.shape)
        x = F.relu(self.bn1(self.conv1(x)))  # input(1, 64,64,64) output(16, 64,64,64)
        # print(x.shape)
        x = self.pool1(x)  # output(16, 32,32,32)
        # print(x.shape)
        x = F.relu(self.bn2(self.conv2(x)))  # output(32, 32,32,32)
        # print(x.shape)
        x = self.pool2(x)  # output(32, 16,16,16)
        # print(x.shape)
        x = F.relu(self.bn3(self.conv3(x)))  # output(64, 16,16,16)
        # print(x.shape)
        x = self.pool3(x)  # output(64, 8,8,8)
        # print(x.shape)
        x = x.view(-1, 64*8*8*8)  # output(64*8*8*8)
        # print(x.shape)
        x = F.relu(self.fc1(x))  # output(1024)
        # print(x.shape)
        x = F.relu(self.fc2(x))  # output(32)
        # print(x.shape)
        x = self.fc3(x)
        # x = F.softmax(x, dim=1)
        # x = self.fc3(x)  # output(2)
        return x
  • 1
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值