dataset
Write a custom dataset class, ReconDataset, that inherits from torch's data.Dataset (torch.utils.data.Dataset) and overrides the __init__, __getitem__, and __len__ methods.
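import os

import mat73
import numpy as np
import torch
from torch.utils import data


class ReconDataset(data.Dataset):
    def __init__(self, root):
        # Per-instance lists; each entry is one sample stored as a NumPy array.
        self.__inputdata = []
        self.__outputdata = []
        self.__inputimg = []
        self.root = os.path.expanduser(root)
        # Build the path from the expanded root so that "~" is honoured.
        folder = os.path.join(self.root, "data_3d_tyz.mat")
        matdata = mat73.loadmat(folder)
        # test: only the first 5 samples are loaded
        for i in range(5):
            a = matdata['signal_small'][:, :, :, i]
            c = matdata['image_small'][:, :, :, :, i]
            b = matdata['mri_small'][:, :, :, i]
            print(type(a))  # <class 'numpy.ndarray'>
            # Prepend an axis of size 1 to the raw signal and to the image volume.
            self.__inputdata.append(a[np.newaxis, :, :, :])
            print(type(self.__inputdata))  # <class 'list'>
            self.__inputimg.append(b[np.newaxis, :, :, :])
            self.__outputdata.append(c)

    def __getitem__(self, index):
        # Return a single sample; batching is done later by the DataLoader.
        rawdata = self.__inputdata[index]
        reconstruction = self.__outputdata[index]
        beamform = self.__inputimg[index]
        # Convert to float32 tensors (the same dtype torch.Tensor(...) yields).
        rawdata = torch.as_tensor(rawdata, dtype=torch.float32)
        reconstructions = torch.as_tensor(reconstruction, dtype=torch.float32)
        beamform = torch.as_tensor(beamform, dtype=torch.float32)
        return rawdata, reconstructions, beamform

    def __len__(self):
        return len(self.__inputdata)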
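The dataset holds three lists, __inputdata, __inputimg, and __outputdata, whose elements are np.ndarray objects. __getitem__ returns one sample at a time, converted to tensors; the DataLoader calls it once per index and stacks batch_size samples into a batch, so what the training code receives is already a batched Tensor. The NumPy arrays are converted to PyTorch Tensors because PyTorch is a Tensor-based framework: Tensors work with the rest of PyTorch's functionality and operations, such as automatic differentiation.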
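The split between per-sample access and batching can be seen in a minimal sketch that replaces the .mat file with small random arrays. The class name `ToyReconDataset`, its public attribute names, and the reduced image shapes are assumptions made purely for illustration; they are not part of the original code.

```python
import numpy as np
import torch
from torch.utils import data


class ToyReconDataset(data.Dataset):
    """Hypothetical stand-in for ReconDataset that uses random arrays."""

    def __init__(self, n_samples=5):
        rng = np.random.default_rng(0)
        # Each list holds one NumPy array per sample (image shapes shrunk for speed).
        self.inputdata = [rng.random((1, 9, 15, 16)) for _ in range(n_samples)]
        self.inputimg = [rng.random((1, 8, 8, 9)) for _ in range(n_samples)]
        self.outputdata = [rng.random((3, 8, 8, 9)) for _ in range(n_samples)]

    def __getitem__(self, index):
        # Convert one sample at a time; float64 arrays become float32 tensors.
        rawdata = torch.as_tensor(self.inputdata[index], dtype=torch.float32)
        recon = torch.as_tensor(self.outputdata[index], dtype=torch.float32)
        beamform = torch.as_tensor(self.inputimg[index], dtype=torch.float32)
        return rawdata, recon, beamform

    def __len__(self):
        return len(self.inputdata)


toy = ToyReconDataset()
raw, recon, bf = toy[0]  # a single sample, no batch dimension
loader = data.DataLoader(toy, batch_size=2, shuffle=False)
raw_b, recon_b, bf_b = next(iter(loader))  # two samples stacked by the loader
print(raw.shape, raw.dtype)
print(raw_b.shape)
```

Indexing the dataset directly gives tensors without a batch axis, while the batch drawn from the loader has an extra leading dimension of size batch_size.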
Loading the dataset with DataLoader
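from torch.utils.data import DataLoader

# dataset_pathr: directory that contains data_3d_tyz.mat
mydataset = ReconDataset(dataset_pathr)
train_loader = DataLoader(
    mydataset,
    batch_size=2, shuffle=True)
# Fetch only the first batch instead of materialising every batch in a list.
batch_idx, (rawdata, reimage, bfimg) = next(enumerate(train_loader))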
Each iteration yields batch_size samples, and batch_idx is the index of the current batch. For training, the loop is: for batch_idx, (rawdata, reimage, bfimg) in enumerate(train_loader):
print('raw:', rawdata.size())  # raw: torch.Size([2, 1, 9, 15, 16])
print(type(rawdata))  # <class 'torch.Tensor'>
print('raw[0]:', rawdata[0].size())  # raw[0]: torch.Size([1, 9, 15, 16])
print('reimage:', reimage.size())  # reimage: torch.Size([2, 3, 128, 128, 144])
print('bfimg:', bfimg.size())  # bfimg: torch.Size([2, 1, 128, 128, 144])
print(len(mydataset))  # 5
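The training-style loop mentioned above can be sketched as follows. `TensorDataset` with zero-filled tensors is used here only as a stand-in for ReconDataset, and the image shapes are shrunk for speed; both are assumptions for illustration, and the loop body is a placeholder rather than an actual training step.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic tensors standing in for the five samples (image shapes shrunk for speed).
raw = torch.zeros(5, 1, 9, 15, 16)
recon = torch.zeros(5, 3, 8, 8, 9)
bf = torch.zeros(5, 1, 8, 8, 9)
loader = DataLoader(TensorDataset(raw, recon, bf), batch_size=2, shuffle=True)

batch_sizes = []
for batch_idx, (rawdata, reimage, bfimg) in enumerate(loader):
    # A real training step (forward pass, loss, backward) would go here.
    batch_sizes.append(rawdata.size(0))
    print(batch_idx, tuple(rawdata.size()))

print(len(loader), batch_sizes)  # 3 [2, 2, 1]
```

With five samples and batch_size=2 the loader produces three batches, and the last one contains a single sample. Passing drop_last=True to DataLoader would discard that incomplete batch.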