pytorch关于数据载入的代码

 第一步:读取文件

class PET_dataset(Dataset):
    def __init__(self,path):
        self.image_path_1 = os.path.join(path,'train1')
        self.image_path_0 = os.path.join(path,'train0')
        self.test_path_1 = os.path.join(path,'test1')
        self.test_path_0 = os.path.join(path,'test0')
        self.img_path_1 = sorted(os.listdir(self.image_path_1))
        self.img_path_0 = sorted(os.listdir(self.image_path_0))
        self.test_path_1 = sorted(os.listdir(self.test_path_1))
        self.test_path_0 = sorted(os.listdir(self.test_path_0))

定义__init__后,执行实例化的过程须变成PET_dataset(path)新建的实例本身,连带其中的参数,会一并传给__init__函数自动并执行它。所以__init__函数的参数列表会在开头多出一项,它永远指代新建的那个实例对象,Python语法要求这个参数必须要有,而名称随意,习惯上就命为self

 第二步:获得文件长度

    def __len__(self):
        return len(self.img_path_0)

第三步:读取文件内数据

    def __getitem__(self, item):
        train1_list, train0_list = self.img_path_1[item], self.img_path_0[item]
        #获得文件内的item
        train1_dcm = os.path.join(self.image_path_1, train1_list)#,allow_pickle=True)
        # img_array=read_data(img_dcm)
        #读取item
        train1_array = nib.load(train1_dcm)              #根据哪个库提取数据
        train1_array = np.array(train1_array.dataobj)    #数据转成numpy

        #label 如上
        train0_dcm = os.path.join(self.image_path_0, train0_list)#,allow_pickle=True)
        # lab_array=read_data(lab_dcm)
        train0_array=nib.load(train0_dcm)
        train0_array = np.array(train0_array.dataobj)


        #将numpy转化为tensor
        train1 = torch.FloatTensor(train1_array)
        train0 = torch.FloatTensor(train0_array)
        #在最前面加一个维度,CNN常用,conv2d的输入必须是四维    
        train1 = train1.unsqueeze(0)
        train0 = train0.unsqueeze(0)
        return  (train1,train0)

第四步:封装起来

def get_load(dir, batch_size,shuffle,num_workers):
    dataset_ = PET_dataset(dir)
    data_loader = DataLoader(dataset=dataset_,batch_size=batch_size,shuffle=shuffle,num_workers=num_workers)
    return data_loader

验证

data_loader = get_load(path,1,True,0)
def train():
    total_iters = 0
    for iter_, (x, y ) in enumerate(data_loader):
        total_iters += 1
        print(x.shape)
        print(y.shape)

train()

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Mario cai

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值