前提:
本文的记录前提是---有一个完整、已调通的pytorch网络项目,因为暂时比赛要用,完整项目等过一段时间再打包发到github上...
比如:加载的pytorch自带cifar数据集:
# train、test图像预处理和增强
transform_train = transforms.Compose(
[transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
transform_test = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
#加载train、test数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
数据预处理torchvision.transforms这一部分主要是进行数据的中心化(torchvision.transforms.CenterCrop)、随机剪切(torchvision.transforms.RandomCrop)、正则化、图片变为Tensor、tensor变为图片等。如果不懂请查看官方文档:TORCHVISION.TRANSFORMS
构建Dataset子类
那么如果我想将cifar数据集换成我自己的数据集怎么办呢?答案是:如果想要使用自己的数据,则必须自己构建一个torch.utils.data.Dataset的子类去读取数据。例如:
from __future__ import print_function
import torch.utils.data as data
import torch
class MyDataset(data.Dataset):
def __init__(self, images, labels):#这一部分用于读取训练、测试数据
self.images = images
self.labels = labels
def __getitem__(self, index):#这一部分将读取的数据输出,返回的是tensor格式
img, target = self.images[index], self.labels[index]
return img, target
def __len__(self):
return len(self.images)
dataset = MyDataset(images, labels)
查看torchvision.datasets.CIFAR10的源码也可以看到cifar10也是继承了Dataset这个类
在定义torch.utils.data.Dataset的子类时,必须重载的两个函数是__len__和__getitem__。__len__返回数据集的大小,__getitem__实现数据集的下标索引,返回对应的图像和标记(不一定非得返回图像和标记,返回元组的长度可以是任意长,这由网络需要的数据决定)。
在创建DataLoader时会判断__getitem__返回值的数据类型,然后用不同的if/else分支把数据转换成tensor,所以,_getitem_返回值的数据类型可选择范围很多,一种可以选择的数据类型是:图像为numpy.array,标记为int数据类型。
实例:
比如这里我需要读入我自己的32*32大小的图像数据,则代码为:
# 图像预处理和增强
transform_train = transforms.Compose(
[transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
transform_test = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), ])
#加载train、test数据集
trainset = MyDataset(imgdir='./Train',imgpath='./Train.txt', transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True, num_workers=2)
testset = MyDataset(imgdir='./Test',imgpath='./Test.txt', train=False, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=8, shuffle=False, num_workers=2)
这里将参数说明一下:MyDataset中imgdir为图片的存放位置,imgpath中包含每张图片的name以及图片的label,其它同cifar
torch.utils.data.DataLoader()
函数,合成数据并且提供迭代访问。主要由两部分组成:
- dataset(Dataset)。输入加载的数据,就是上面的MyDataset的实现。
- batch_size, shuffle, sampler, batch_sampler, num_worker, collate_fn, pin_memory, drop_last, timeout等参数,介绍几个比较常用的,这些在官方网站都有:
- batch-size。样本每个batch的大小,默认为1。
- shuffle。是否打乱数据,默认为False。
- num_workers。数据分为几个线程处理默认为0。
- sampler。定义一个方法来绘制样本数据,如果定义该方法,则不能使用shuffle。默认为False
之后就是在MyDataset里读取和输出自定义数据集: 简而言之就是在继承了Dataset类的Mydataset里,__init__函数中读取图像数据,DataLoader再通过
__getitem__中获取输出
import os
import cv2
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, imgdir, imgpath,train=True,
transform=None, target_transform=None):
self.root = os.path.expanduser(imgdir)
self.transform = transform
self.target_transform = target_transform
self.train = train # training set or test set
# now load the picked numpy arrays
if self.train:
self.train_data = []#images
self.train_labels = []#labels
#read the images and labels in file Train.txt or Test.txt
with open(imgpath,"r") as imgpath:
for line in imgpath:
line=line.split(' ')
image=Image.open(line[0])
image = np.array(image)
self.train_data.append(image)#将读取的图片放入train_data list里
self.train_labels.append(int(line[1]))#将读取图片的对应label放入train_labels里
imgpath.close()
self.train_data = np.array(self.train_data)
self.train_data = self.train_data.reshape((1000,3, 32, 32))
self.train_data = self.train_data.transpose((0, 2, 3, 1)) # convert to HWC
else:
self.test_data = []#images
self.test_labels = []#labels
#read the images and labels in file Train.txt or Test.txt
with open(imgpath,"r") as imgpath:
for line in imgpath:
line = line.split(' ')
image=Image.open(line[0])
image = np.array(image)
self.test_data.append(image)#将读取的图片放入train_data list里
self.test_labels.append(int(line[1]))#将读取图片的对应label放入train_labels里
imgpath.close()
self.test_data = np.array(self.test_data)
self.test_data = self.test_data.reshape((200, 3,32, 32))
self.test_data = self.test_data.transpose((0, 2, 3, 1)) # convert to HWC
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
if self.train:
img, target = self.train_data[index], self.train_labels[index]
else:
img, target = self.test_data[index], self.test_labels[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img,target
def __len__(self):
if self.train:
return len(self.train_data)
else:
return len(self.test_data)
参考:
https://blog.csdn.net/GYGuo95/article/details/78821520
https://blog.csdn.net/victoriaw/article/details/72356453
https://pytorch.org/docs/stable/torchvision/transforms.html
【房屋出租】通州北京宇涵文化创意园https://www.douban.com/note/763129847/