Dataset与Dataloader

dataset

编写自己的数据集ReconDataset类,这个类将会继承torch中的data.Dataset类,重写init函数,getitem函数,len函数

class ReconDataset(data.Dataset):
    __inputdata = []
    __inputimg = []
    __outputdata = []
    a = []
    b = []
    c = []
    def __init__(self,root):
        self.__inputdata = []
        self.__outputdata = []
        self.__inputimg = []
        self.root = os.path.expanduser(root)
                    folder = root + "data_3d_tyz.mat"
        matdata = mat73.loadmat(folder)
            # tset
        for i in range(5):
           a = matdata['signal_small'][ :, :, :,i]
           c = matdata['image_small'][:,:, :, :,i]
           b = matdata['mri_small'][:, :, :,i]
           print(type(a))#<class 'numpy.ndarray'>
           self.__inputdata.append(a[np.newaxis, :, :,:])
           print(type(self.__inputdata))#<class 'list'>
           self.__inputimg.append(b[np.newaxis, :, :, :])
           self.__outputdata.append(c)
    def __getitem__(self, index):
        rawdata =  self.__inputdata[index] 
        reconstruction =self.__outputdata[index] 
        beamform = self.__inputimg[index]


        rawdata = torch.Tensor(rawdata)
        reconstructions = torch.Tensor(reconstruction)
        beamform = torch.Tensor(beamform)

        return rawdata, reconstructions,beamform

    def __len__(self):
        return len(self.__inputdata)

dataset里有三个列表,一个是__inputdata,一个是__inputimg,一个是__outputdata,列表里存的是np.array类型数据。当要取数据时,会调用getitem,取得列表中的单个或者多个数据,单个还是多个就是batch_size,这时候取得的数据就是tensor类型了。将numpy数组转换为PyTorch的Tensor是因为PyTorch是一个基于Tensor的计算框架,使用Tensor作为数据类型可以更好地与PyTorch的其它功能和操作兼容,比如自动求导。

用dataloader上传dataset

    mydataset = ReconDataset(dataset_pathr)

    train_loader = DataLoader(
        mydataset,
        batch_size=2, shuffle=True)
    batch_idx, (rawdata, reimage, bfimg) = list(enumerate(train_loader))[0]

一次取batch_size个数据,batch_idx是取的批次编号,训练的话就是 for batch_idx, (rawdata ,reimage,bfimg) in enumerate(train_loader):

    print('raw:',rawdata.size())#raw: torch.Size([2, 1, 9, 15, 16])
    print(type(rawdata))#<class 'torch.Tensor'>
    print('raw[0]:', rawdata[1].size())#raw[0]: torch.Size([1, 9, 15, 16])
    print('reimage:', reimage.size())#reimage: torch.Size([2, 3, 128, 128, 144])
    print('bfimg:', bfimg.size())#bfimg: torch.Size([2, 1, 128, 128, 144])
    print(mydataset.__len__())#5
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值