HCP_DataLoader

鉴于数据太大,肯定要用S3来存
S3的读取速度应该还是很快的

from torch.utils.data import Dataset,DataLoader
import torch
import numpy as np
import nibabel as nbl

# Text file that lists the data file paths, one per line.
HCPDataListPath = '/home/ec2-user/SageMaker/Models_HCP/dt.txt'
# Prefix that is concatenated in front of every listed path when loading.
HCPDataRootPath = '/home/ec2-user/SageMaker/HCP_dataset/'
# Task name -> integer class id (ids follow the order of the tuple below).
labelNumber = {
    taskName: classId
    for classId, taskName in enumerate(
        ('EMOTION', 'GAMBLING', 'LANGUAGE', 'MOTOR',
         'RELATIONAL', 'SOCIAL', 'WM')
    )
}

def getLabelList(HCPDataPathList):
    """Translate data paths into an array of integer task labels.

    For each path, the second underscore-separated field is taken and cut at
    its first '.'; the resulting task name is looked up in the module-level
    ``labelNumber`` table.

    Args:
        HCPDataPathList: iterable of path strings.

    Returns:
        numpy array holding one label per path, in input order.

    Raises:
        IndexError: if a path contains no underscore.
        KeyError: if the extracted task name is not in ``labelNumber``.
    """
    def taskNameOf(path):
        secondField = path.split('_')[1]
        return secondField.split('.')[0]

    return np.array([labelNumber[taskNameOf(path)] for path in HCPDataPathList])

def getData(HCPDataRootPath, dataPath):
    """Load a CIFTI file and scatter its voxel values into a dense volume.

    The file's data array is flattened, a zero-filled grid is allocated with
    the volume dimensions reported by index map 1, and the values belonging
    to every non-surface brain model are written at that model's voxel
    indices. Surface brain models are skipped.

    Args:
        HCPDataRootPath: prefix string concatenated in front of ``dataPath``.
        dataPath: path of the file relative to the root prefix.

    Returns:
        float32 numpy array of shape (1, 91, 109, 91).

    NOTE(review): because the data is flattened before slicing, this
    presumably expects a single-map file -- confirm against the dataset.
    """
    cifti = nbl.load(HCPDataRootPath + dataPath)
    flatValues = cifti.get_fdata().reshape((-1))
    indexMap = cifti.header.matrix.get_index_map(1)
    volume = np.zeros(indexMap.volume.volume_dimensions)
    for brainModel in indexMap.brain_models:
        if brainModel.model_type != 'CIFTI_MODEL_TYPE_SURFACE':
            start = brainModel.index_offset
            stop = start + brainModel.index_count
            # Turn the (count, 3) ijk list into a 3-tuple of index arrays.
            ijk = tuple(np.transpose(brainModel.voxel_indices_ijk))
            volume[ijk] = flatValues[start:stop]
    # The reshape only succeeds when the volume grid is 91 x 109 x 91.
    return volume.reshape((1, 91, 109, 91)).astype(np.float32)

class HCPDataSet(Dataset):
    """Map-style torch dataset that loads one HCP sample per index.

    Samples are loaded lazily: only the path list and the label list are held
    in memory, and the file is read by ``getData`` on every ``__getitem__``.
    """

    def __init__(self,HCPDataRootPath,HCPDataPathList,HCPLabelList):
        """Store the root prefix, the per-sample paths and their labels.

        Args:
            HCPDataRootPath: prefix string passed to ``getData``.
            HCPDataPathList: sequence of per-sample paths.
            HCPLabelList: sequence of labels, aligned with the path list.
        """
        super().__init__()
        self.HCPDataRootPath = HCPDataRootPath
        self.HCPDataPathList = np.array(HCPDataPathList)
        self.HCPLabelList = np.array(HCPLabelList)
        self.total = len(HCPDataPathList)

    def __getitem__(self,index):
        """Return ``(data, label)`` tensors for the sample at ``index``."""
        dataPath = self.HCPDataPathList[index]
        data = getData(self.HCPDataRootPath,dataPath)
        # Wrap the scalar label in a 0-d array so torch.from_numpy accepts it.
        label = np.array(self.HCPLabelList[index])
        return torch.from_numpy(data),torch.from_numpy(label)

    def __len__(self):
        """Return the number of samples given at construction time."""
        return self.total

class WholeDataSet:
    """Simple holder bundling the train, evaluation and test datasets."""

    def __init__(self, trainDataSet, evalDataSet, testDataSet):
        """Keep the three datasets as attributes of the same names."""
        (self.trainDataSet,
         self.evalDataSet,
         self.testDataSet) = (trainDataSet, evalDataSet, testDataSet)

def getHCPDataSet(HCPDataRootPath,HCPDataListPath,evalRate=0.2,testRate=0.2,tiny_data=0):
    """Read the path list file and split it into train/eval/test datasets.

    The split is positional: the first paths go to training, the next ones to
    evaluation and the remainder to test. No shuffling is performed here.

    Args:
        HCPDataRootPath: prefix string handed to every ``HCPDataSet``.
        HCPDataListPath: text file with one data path per line.
        evalRate: fraction of the paths used for evaluation.
        testRate: fraction of the paths reserved for test; the test split
            receives whatever is left after training and evaluation.
        tiny_data: if non-zero, only the first ``int(tiny_data)`` paths are
            used.

    Returns:
        A ``WholeDataSet`` holding the three ``HCPDataSet`` splits.
    """
    with open(HCPDataListPath,'r') as fr:
        # Drop blank lines: an empty path has no '_' field and would make
        # getLabelList fail with IndexError.
        dataPathList = [line.strip() for line in fr if line.strip()]

    if tiny_data != 0:
        dataPathList = dataPathList[:int(tiny_data)]

    totalNumber = len(dataPathList)
    totalTraining = int(totalNumber * (1 - evalRate - testRate))
    totalEvaluation = int(totalNumber * evalRate)
    totalTest = totalNumber - totalTraining - totalEvaluation
    print('Training : {},Evaluation: {},Test: {}'.format(totalTraining,totalEvaluation,totalTest))

    def buildSplit(pathList):
        # One dataset per split, with labels derived from the paths themselves.
        return HCPDataSet(HCPDataRootPath, pathList, getLabelList(pathList))

    evalEnd = totalTraining + totalEvaluation
    trainDataSet = buildSplit(dataPathList[:totalTraining])
    evalDataSet = buildSplit(dataPathList[totalTraining:evalEnd])
    testDataSet = buildSplit(dataPathList[evalEnd:])

    return WholeDataSet(trainDataSet,evalDataSet,testDataSet)

# Build the splits when the module is executed: 5% of the listed paths go to
# evaluation, 15% are requested for test, and the rest is used for training.
dataSet = getHCPDataSet(HCPDataRootPath, HCPDataListPath, evalRate=0.05, testRate=0.15)
def getData(HCPDataRootPath, dataPath):
    """Load one volume and min-max normalise it into the range [0, 1].

    This definition rebinds the name ``getData``, so it replaces the earlier
    ``getData`` defined in this module for every call made afterwards.

    Args:
        HCPDataRootPath: prefix string concatenated in front of ``dataPath``.
        dataPath: path of the file relative to the root prefix.

    Returns:
        float32 numpy array of shape (1, 91, 109, 91) scaled as
        ``(data - min) / (max - min)``. If every value is identical, an
        all-zero array is returned instead of dividing by zero.

    NOTE(review): the file's data is reshaped directly, so it must hold
    exactly 91 * 109 * 91 values -- presumably a NIfTI volume; confirm.
    """
    data = nbl.load(HCPDataRootPath + dataPath).get_fdata().reshape((-1))
    data = np.array(data).reshape((1, 91, 109, 91)).astype(np.float32)
    low = np.min(data)
    valueRange = np.max(data) - low
    if valueRange == 0:
        # Constant volume: nothing to scale, and 0/0 would yield NaN.
        return np.zeros_like(data)
    # Parenthesised denominator: the original divided by max alone and then
    # subtracted min, which is not a min-max normalisation.
    return (data - low) / valueRange
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值