#%%
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import PIL.Image as Image
import matplotlib.pyplot as plt
from torchvision import transforms
#%% md
'''
数据量小
数据量小的时候,没有大问题,直接加载到内存。比如我们利用一些数据做的线性回归
数据量大
数据量大的时候,将所有的数据读取到内存中训练就会内存不够。而大数据量是非常常见的现象。
思路:Dataloader的处理逻辑是先通过Dataset类里面的 __getitem__ 函数获取单个的数据,然后组合成batch,
再使用collate_fn所指定的函数对这个batch做一些操作,比如padding啊之类的。如果没有使用collate_fn,默认就是基本的操作。
实现:
Dataset类:Pytorch读取图片,主要是通过Dataset类,Dataset类作为所有的datasets的基类存在,所有的datasets都需要继承它,类似于C++中的虚基类
1、需要继承Dataset类
2、需要实现__getitem__(self, index),及__len__(self)方法。__getitem__方法输入一个index(通常指图片数据的路径和标签信息),输出图片数据和标签;
__len__方法返回数据集的大小
说明:
1、Dataset类及其子类是迭代器
2、DataLoader类是迭代器
'''
#%% 肝脏数据类
def make_dataset(root):
imgsPath=[]
n=len(os.listdir(root))//2
for i in range(n):
img=os.path.join(root, "%03d.png"%i)
mask=os.path.join(root,"%03d_mask.png"%i)
imgsPath.append((img,mask))
return imgsPath
class LiverDataset(Dataset):
# 初始化的目的是获取所有的图像或者所有图像的索引,方便__getitem__读取
def __init__(self,path):
imgsPath = make_dataset(path) #"data/train",获取路径下的所有文件路径
self.imgsPath = imgsPath #存放着图片路径
def __getitem__(self, index):
x_path, y_path = self.imgsPath[index] #输入图像路径,标签图像路径
img_x = Image.open(x_path)
img_y = Image.open(y_path)
trans_x = transforms.ToTensor()
trans_y = transforms.ToTensor()
return trans_x(img_x), trans_y(img_y) #输入图像,标签图像(转换为tensor)
def __len__(self):
return len(self.imgsPath) #数据集的组数(1输入图像+1标签图像=1组)
#输入,1)训练集路径 2)输入图片处理方式 3)标签处理方式
batch_size = 2
path = r'G:\code\python\deeplearning\project-U-Net\data\train'
liver_dataset = LiverDataset(path)
dataloaders = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
#%% LiverDataset是一个迭代器
a = next(iter(liver_dataset))
t = transforms.ToPILImage()
plt.imshow(t(a[0]))
plt.show()
plt.imshow(t(a[1]))
plt.show()
#%% DataLoader是一个迭代器
x = dataloaders.__iter__().__next__()
t = transforms.ToPILImage()
plt.imshow(t(x[0][0,:,:,:]))
plt.show()
plt.imshow(t(x[1][0,:,:,:]))
plt.show()
# print(dataloaders.__len__())
# for x,y in dataloaders:
# x
# y
i=0,for x,y in dataloaders: 打散dataset数据索引,遍历整个dataset数据
i=1,for x,y in dataloaders:重新打散dataset数据索引,继续遍历dataset数据
for i in range(2):
for x,y in dataloaders:
print((x))
print((y))
情形2:DataLoader会依次读取 顺序/打乱的迭代器,当数据读取完再进行读取时候,raise StopIteration异常
from torch.utils.data.dataloader import DataLoader
loader = DataLoader(dataset=range(100),batch_size=1,shuffle=True)
data = iter(loader)
for i in range(101):
print(i,next(data))
StopIteration
可能会有疑问,为什么一般代码中写for step in training_data_loader: 执行,再下一个epoch的时候,却没有给出StopIteration异常!!!注意,通过查阅相关资料可以知道,在for循环(当前epoch)的过程中,都使用的同一个迭代器,使用next方法获取数据,当数据遍历一遍,迭代器销毁。下一个epoch时,使用的是另外一个迭代器,因此不会发生停止迭代异常。
情形3:batchSize >1 时,队列中最后一个数据不满足batchSize大小,输出个数则小于batchSize
from torch.utils.data.dataloader import DataLoader
loader = DataLoader(dataset=range(100),batch_size=3,shuffle=False)
data = iter(loader)
for i in range(101):
print(i,next(data))
输出:
......
30 tensor([90, 91, 92])
31 tensor([93, 94, 95])
32 tensor([96, 97, 98])
33 tensor([99])
文件夹中列出所有文件,并添加到自定义dataset类中
class tublinDataset_selfDefineTest(Dataset): def __init__(self,dirName): super(tublinDataset_selfDefineTest, self).__init__() if not os.path.isdir(dirName): raise ValueError('input file_path is not a dir') self.dirName = dirName # 获取路径下所有的图片名称,必须保证路径内没有图片以外的数据 self.image_list = os.listdir(self.dirName)