Pytorch数据集加载
Dataset
如果弄明白了pytorch
中dataset类,你可以创建适应任意模型的数据集接口。
所谓数据集,无非就是一组{x:y}的集合,你只需要在这个类里说明“有一组=={x:y}==的集合”就可以了。
对于图像分类任务,图像+分类
对于目标检测任务,图像+bbox、分类
对于超分辨率任务,低分辨率图像+超分辨率图像
对于文本分类任务,文本+分类
可以通过.txt文件加载
/home/muzhan/projects/dataset/test/250_04.png _0
/home/muzhan/projects/dataset/test/250_05.png _7
/home/muzhan/projects/dataset/test/250_06.png _3
/home/muzhan/projects/dataset/test/250_07.png _2
/home/muzhan/projects/dataset/test/250_08.png _2
/home/muzhan/projects/dataset/test/250_09.png _3
/home/muzhan/projects/dataset/test/250_10.png _4
/home/muzhan/projects/dataset/test/250_11.png _0
/home/muzhan/projects/dataset/test/250_12.png _9
重新定义自己的dataset类
from torch.utils.data import Dataset
class MyDataSet(Dataset):
def __init__(self, dataset_type, transform=None, update_dataset=False):
"""
dataset_type: ['train', 'test']
"""
dataset_path = '/home/muzhan/projects/dataset/'
if update_dataset:
make_txt_file(dataset_path) # update datalist
self.transform = transform
self.sample_list = list()
self.dataset_type = dataset_type
f = open(dataset_path + self.dataset_type + '/datalist.txt')
lines = f.readlines()
for line in lines:
self.sample_list.append(line.strip())
f.close()
def __getitem__(self, index):
item = self.sample_list[index]
# img = cv2.imread(item.split(' _')[0])
img = Image.open(item.split(' _')[0])
if self.transform is not None:
img = self.transform(img)
label = int(item.split(' _')[-1])
return img, label
def __len__(self):
return len(self.sample_list)
Dataloader
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
dataset:定义的dataset类返回的结果。
batchsize:每个bacth要加载的样本数,默认为1。
shuffle:在每个epoch中对整个数据集data进行shuffle重排,默认为False。
sample:定义从数据集中加载数据所采用的策略,如果指定的话,shuffle必须为False;batch_sample类似,表示一次返回一个batch的index。
num_workers:表示开启多少个线程数去加载你的数据,默认为0,代表只使用主进程。
collate_fn:表示合并样本列表以形成小批量的Tensor对象。
pin_memory:表示要将load进来的数据是否要拷贝到pin_memory区中,其表示生成的Tensor数据是属于内存中的锁页内存区,这样将Tensor数据转义到GPU中速度就会快一些,默认为False。
drop_last:当你的整个数据长度不能够整除你的batchsize,选择是否要丢弃最后一个不完整的batch,默认为False。
enumerate()
enumerate() 函数用于将一个可遍历的数据对象(如列表、元组或字符串)组合为一个索引序列,同时列出数据和数据下标,一般用在 for 循环当中。
遍历数据和标签
>>> import torch
>>> batch_data = torch.randn(10)
>>> batch_data
# tensor([-1.4227, 0.4803, -0.1308, -0.9972, -1.2646, -0.7575, -0.6185, 0.3919,
-0.9820, -0.1905])
>>> labels = torch.linspace(1,10,10)
>>> labels
# tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
>>> for i, (data, labels) in enumerate(zip(batch_data, labels)):
print("{}: {},{}".format(i, data, labels))
# 0: -1.4227263927459717,1.0
1: 0.48032230138778687,2.0
2: -0.13082626461982727,3.0
3: -0.9972370266914368,4.0
4: -1.2645894289016724,5.0
5: -0.7574924230575562,6.0
6: -0.6185144782066345,7.0
7: 0.39187055826187134,8.0
8: -0.9819689989089966,9.0
9: -0.19045710563659668,10.0
一般图片来说输入是
B x C x H x W 分别是 批量, 通道,高,宽
输出是
B x num_classes
训练过程中可用VISDOM进行可视化
注意充github源码进行安装不然可能失败