前言
pytorch对于怎么样把数据放进神经网络训练有一套非常成熟的机制,我们只需要按照流程即可,这个流程只要是涉及了Dataset、DataLoader和Transform
这篇博客参考了:
(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform
(第二篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform
(第三篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform
步骤
还是拿前一篇文章的例子pytorch系列教程(一)-训练和测试模型流程来讲述如何把数据放到神经网络中训练的
#前一篇文章的代码
train_datasets = MyDataset() # 第一步:构造Dataset对象
train_dataloader = DataLoader(train_datasets)# 第二步:通过DataLoader来构造迭代对象
以yolov1网络为例子,data/train.txt中的内容
下面来看看 MyDataset()中要做什么
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
import numpy as np
from torchvision.transforms import ToTensor
from torchvision import datasets, transforms
import random
class MyDataset()(Dataset):
def __init__(self,transform):
self.transform = transform
self.img_list = []
self.labels=np.zeros((7,7,5*NUM_BBOX+len(CLASSES)))
#data/train.txt存放的是训练数据的路径
with open("data/train.txt", 'r') as f:
#将训练数据的路径放到img_list中
self.img_list = [x.strip() for x in f]
def __getitem__(self, idx):
###################################
#idx最大值为len(self.img_list)-1
img = cv2.imread(self.img_list[idx])
if self.transform:
#将数据转成tensor
img = self.transform(img)
label= self.transform(label)
return img,label
def __len__(self):
#返回训练数据的长度
return len(self.img_list)
def main():
transform=transforms.Compose([
transforms.ToTensor()
])
train_datasets = MyDataset(transform) # 第一步:构造Dataset对象
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)# 第二步:通过DataLoader来构造迭代对象
if __name__ == '__main__':
main()
总结
看完代码之后总结一下
1、构造一个类继承Dataset,例如 MyDataset(Dataset)。在类中
首先新建一个变量用来存放训练数据或者标签的路径,例如self.img_list
然后在类中重写 getitem(self, index)和__len__(self)
- getitem(self, index)中主要做的是返回的img和label。这个img和label都已经是tensor形式
- len(self)中主要做的是返回训练数据的总长度
Dataset详解
Dataset类是Pytorch中图像数据集中最为重要的一个类,也是Pytorch中所有数据集加载类中应该继承的父类。其中Dataset类中的两个私有成员函数必须被重载,否则将会触发错误提示:
- def getitem(self, index):
- def len(self):
- def init(self):
构造函数一般情况下我们也是要自己定义的,但是不是强制性的。
其中__len__应该返回数据集的大小,而__getitem__应该编写支持数据集索引的函数,例如通过dataset[i]可以得到数据集中的第i+1个数据。这个Dataset抽象父类的定义如下:
class Dataset(object):
def __getitem__(self, index):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
def __add__(self, other):
return ConcatDataset([self, other])
2、利用transforms.Compose添加图像变换
transforms中的图像变换操作大全
__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
"Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip",
"RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation",
"ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale"]
3、通过DataLoader来构造迭代对象
看一下DataLoader的定义
class DataLoader(object):
__initialized = False
def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
def __setattr__(self, attr, val):
def __iter__(self):
def __len__(self):
Arguments:
dataset (Dataset): 是一个DataSet对象,表示需要加载的数据集.
batch_size (int, optional): 每一个batch加载多少组样本,即指定batch_size,默认是 1
shuffle (bool, optional): 布尔值True或者是False ,表示每一个epoch之后是否对样本进行随机打乱,默认是False