import torch
import numpy as np
from torch.utils.data import DataLoader
class Getloader(torch.utils.data.Dataset):
def __init__(self,data_root,data_label):
self.data = data_root
self.label = data_label
def __getitem__(self, index):
data = self.data[index]
labels = self.label[index]
return data , labels
def __len__(self):
return len(self.data)
这部分是读取数据集的模板,可以把他当成一个容器,这个要配合DataLoader使用发挥作用。
实例化:
source_data = '图片路径'
source_label= ’图片路径‘
data = Getloader(source_data,source_label)
DataLoader:
函数介绍:
torch.utils.data.DataLoader(dataset,batch_size,shuffle,drop_last,num_workers)
- dataset: 加载torch.utils.data.Dataset对象数据
- batch_size: 每个batch的大小
- shuffle:是否对数据进行打乱
- drop_last:是否对无法整除的最后一个datasize进行丢弃
- num_workers:表示加载的时候子进程数
实例化:
datas = DataLoader(data,batch_size=5,shuffle=True,drop_last=False)
查看实例化其结果:
for i, data in enumerate(datas):
# i表示第几个batch, data表示该batch对应的数据,包含data和对应的labels
print("第 {} 个Batch \n{}".format(i, data))
#这个是训练网络时调用的方法
for data , label in datas:
print(data)
print(label) #输出的是张量数组
这个DataLoader会返回Dataset里面的_getitem_中的return值,并且每次返回batch_size=5个值(这里的data),他会把这5个值放到一个序列里(列表、元组),这个5个data和5个label很好理解是下标对应的关系(故可以作为同一对训练集使用)。
当我们训练的数据集为图片时,dataset中_getitem_返回的要是张量,这时设置batch_size=1,则读取的就是单个的图片和对应的标签图片了。
注:用dataloader得到dataset里的图片自动由numpy转化为tensor,故这里不需要手动转化,而在预测的时候,你要手动转化一下
完整测试代码贴上去,大家可以试一下
import torch from torch.utils.data import DataLoader class Getloader(torch.utils.data.Dataset): def __init__(self,data_root,data_label): self.data = data_root self.label = data_label def __getitem__(self, index): data = self.data[index] labels = self.label[index] return data ,labels def __len__(self): return len(self.data) source_data = 'abcdefgh' source_label= [1,2,3,4,5,6,7,8] data = Getloader(source_data,source_label) datas = DataLoader(data,batch_size=3,shuffle=True,drop_last=False) for i ,da in datas: print('-----------') print(i) print(da)
这是结果:
D:\Anaconda\envs\pytorch3.7\python.exe D:/project/pytorch/try/main.py
-----------
('g', 'h', 'c')
tensor([7, 8, 3])
-----------
('d', 'f', 'a')
tensor([4, 6, 1])
-----------
('b', 'e')
tensor([2, 5])
进程已结束,退出代码为 0
再贴一个可以使用的例子:
import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
class ISBI_Loader(Dataset):
def __init__(self, data_path):
# 初始化函数,读取所有data_path下的图片
self.data_path = data_path
self.imgs_path = glob.glob(os.path.join(data_path, 'image2/T*.png')) #得到一个list(遍历路径下的所有)
self.labs_path = glob.glob(os.path.join(data_path, 'label2/P*.jpg'))
def __getitem__(self, index):
# 根据index读取图片
image_path = self.imgs_path[index]
# 根据image_path生成label_path
label_path = self.labs_path[index]
# 读取训练图片和标签图片
image = cv2.imread(image_path)
label = cv2.imread(label_path)
return image, label
def __len__(self):
# 返回训练集大小
return len(self.imgs_path)