在训练模型前,最重要的部分就是制作好数据集,有些情况下,由于图片数据过多,然后存储很不方便,我们就需要将数据制作成npy类型的数据格式。npy数据格式是一个四维的数组[N,H,W, C],其中N代表数据集的总数,H, W,C分别代表每一张图片对应的长、宽、以及通道数。
数据制作好之后,就是如何加载数据问题,TF中加载数据相对比较容易,但是Pytorch中,我们一般都是将数据制作成dataset,再传入Dataloader进行加载,因此就需要继承Dataset的类,然后编写读取npy的数据格式。Dataset中,我们需要定义三个函数。
一、__init__(self,data) 函数
主要是用来加载npy数据的,也可以加载数据预处理的函数,比如将数据转化为tensor之类的操作
def __init__(self, data):
self.data = np.load(data) #加载npy数据
self.transforms = transform #转为tensor形式
二、__len__(self)函数
这个函数就是用来返回数据的总个数
def __len__(self):
return self.data.shape[0] #返回数据的总个数
三、 __getitem__(self,index)函数
这个是最要的函数,类似一个for循环,从头开始,每次读取一个保存在npy里面的数据,然后进行处理后,可以同时返回训练数据,以及对应的标签
def __getitem__(self, index):
hdct= self.data[index, :, :, :] # 读取每一个npy的数据
hdct = np.squeeze(hdct) # 删掉一维的数据,就是把通道数这个维度删除
ldct = 2.5 * skimage.util.random_noise(hdct * (0.4 / 255), mode='poisson', seed=None) * 255 #加poisson噪声
hdct=Image.fromarray(np.uint8(hdct)) #转成image的形式
ldct=Image.fromarray(np.uint8(ldct)) #转成image的形式
hdct= self.transforms(hdct) #转为tensor形式
ldct= self.transforms(ldct) #转为tensor形式
return ldct,hdct #返回数据还有标签
完整的代码如下:
import torch
import numpy as np
import skimage
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
torch.manual_seed(1) # reproducible
transform = transforms.Compose([
transforms.ToTensor(), # 将图片转换为Tensor,归一化至[0,1]
])
'''NPY数据格式'''
class MyDataset(Dataset):
def __init__(self, data):
self.data = np.load(data) #加载npy数据
self.transforms = transform #转为tensor形式
def __getitem__(self, index):
hdct= self.data[index, :, :, :] # 读取每一个npy的数据
hdct = np.squeeze(hdct) # 删掉一维的数据,就是把通道数这个维度删除
ldct = 2.5 * skimage.util.random_noise(hdct * (0.4 / 255), mode='poisson', seed=None) * 255 #加poisson噪声
hdct=Image.fromarray(np.uint8(hdct)) #转成image的形式
ldct=Image.fromarray(np.uint8(ldct)) #转成image的形式
hdct= self.transforms(hdct) #转为tensor形式
ldct= self.transforms(ldct) #转为tensor形式
return ldct,hdct #返回数据还有标签
def __len__(self):
return self.data.shape[0] #返回数据的总个数
def main():
dataset=MyDataset('.\data_npy\img_covid_poisson_glay_clean_BATCH_64_PATS_100.npy')
data= DataLoader(dataset, batch_size=64, shuffle=True, pin_memory=True)
if __name__ == '__main__':
main()