p y t o r c h 中 的 D a t a L o a d e r 与 D a t a S e t pytorch中的DataLoader与DataSet pytorch中的DataLoader与DataSet
class torch.utils.data.Dataset
决定数据从哪读取,如何读取,进行何种预处理
表示Dataset的抽象类。
所有子类应该override __len__
和__getitem__
,前者提供了数据集的大小
,后者支持整数索引
,范围从0到len(self)。
LoadDataset(data_dir=major_config.train_image, transform=train_transform)
import os
import random
from PIL import Image
from torch.utils.data import Dataset
import major_config
random.seed(1)
# 类别对应表
dict_label = major_config.dict_label
# 返回所有图片路径和标签
def get_img_label(data_dir):
img_label_list = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
# img_names = list(filter(lambda x: x.endswith('.png'), img_names)) # 如果改了图片格式,这里需要修改
# 遍历图片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = dict_label[sub_dir]
img_label_list.append((path_img, int(label)))
return img_label_list
# 主要是用来接受索引返回样本用的
class LoadDataset(Dataset):
def __init__(self, data_dir, transform=None):
# 获取所有图片的路径、label , 和 确定预处理操作
self.img_label_list = get_img_label(data_dir) # img_label_list,在DataLoader中通过index读取样本
self.transform = transform
#接受一个索引,返回一个样本 --- img, label
def __getitem__(self, index):
path_img, label = self.img_label_list[index]
img = Image.open(path_img).convert('RGB') # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label
def __len__(self):
return len(self.img_label_list)
__getitem__
的主要作用
主要是用来接受索引返回样本用的。(Sample :Index 生成索引)
__len__
的主要作用
__getitem__
接受索引的范围就是__len__
里确定的范围
class torch.utils.data.DataLoader
class torch.utils.data.DataLoader( dataset,
batch_size=1,
shuffle=False,
sampler=None,
num_workers=0,
collate_fn=<function default_collate>,
pin_memory=False,
drop_last=False
)
dataset (Dataset)
– 加载数据的数据集。batch_size (int, optional)
– 每个batch加载多少个样本(默认: 1)。shuffle (bool, optional)
– 设置为True时会在每个epoch重新打乱数据(默认: False).sampler (Sampler, optional)
– 定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数。num_workers (int, optional)
– 用多少个子进程加载数据。0表示数据将在主进程中加载(默认: 0)drop_last (bool, optional)
– 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)。即当样本数不能被batchsize整除时,是否舍弃最后一批数据
_SingleProcessDataLoaderIter
def _next_data(self):
_sampler_iter
sampler
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
def _next_data(self):
index = self._next_index() # may raise StopIteration
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
if self._pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data