文章目录
一.数据处理箱概述
pytorch数据处理只要是torch.utils.data与torchvision模块。
1.torch.utils.data
其常用的主要为以下四大类:
1)Dataset:一个抽象类,自定义数据集需要继承这个类,并重写两个函数,一个为__len__,另一个为__getitem__,前者提供数据大小,后者通过制定索引获取数据和标签,但每次只能获取一个数据。
2)DataLoader:实现批量处理读取数据,并提供打乱数据,提供并行加速等功能。
2.torchvision
其是pytorch的一个视觉处理工具包,独立于Pytorch。
其主要包含以下四类:
1)datasets:提供常用的数据集加载,里面含有许多数据集,如Minist等。
2)models:提供深度学习许多经典的网络模型以及训练好的模型(选择pretrained=True)
3)transforms:数据预处理操作,主要针对张量(Tensor)和图片(PIL Image)对象的操作。
4)utils:含有两个函数,一个为make_grid,他能将多张图片拼接在一个网格中;另一个是save_img,他能将Tensor保存为图片。
二. 具体实例
1.通过torch.utils.Dataset构建自定义数据集
mport torch
from torch.utils.data import Dataset,DataLoader
"""继承Dataset类"""
class Test(Dataset):
def __init__(self):
self.data=torch.tensor([[1,2],[3,4],[5,6],[7,8]])
self.label=torch.tensor([1,2,3,4])
def __getitem__(self,index):
return self.data[index],self.label[index]
def __len__(self):
return len(self.data)
My=Test()
length=My.__len__()
for i in range(length):
print(My[i])
运行结果:
2.通过DataLoader对数据进行批处理
DataLoader模块常用格式:
data_loader=data.DataLoader(dataset,batch_size,shuffle)
参数说明:
dataset:数据集
batch_size:批处理大小
shuffle:是否将数据打乱,打乱为True,默认为False
data_loader:返回值,类型为<class 'torch.utils.data.dataloader.DataLoader'>,但可以类似与像元组那样获取数据
import torch
from torch.utils.data import Dataset,DataLoader
"""继承Dataset类"""
class Test(Dataset):
def __init__(self):
self.data=torch.tensor([[1,2],[3,4],[5,6],[7,8]])
self.label=torch.tensor([1,2,3,4])
def __getitem__(self,index):
return self.data[index],self.label[index]
def __len__(self):
return len(self.data)
My=Test()
data_loader=DataLoader(My,batch_size=2,shuffle=True)
for i,loader in enumerate(data_loader):
data,label=loader
print("index:",i)
print("data:",data,"label:",label)
运行结果:
3.transforms对张量和图像的常见操作
常见用法见官网
官网地址:
https://pytorch.org/docs/stable/torchvision/transforms.html#transforms-on-pil-image
4.ImageFolder
当我们的图片依据其标签分别处于不同文件时,类似于下面:
每个文件中的图片属于同一标签。
我们此时可以利用torchvision.datasets.ImageFolder来直接构造dataset,
常用代码格式如下:
dataset=torchvision.datasets.ImageFolder(path,transform)
参数说明:
path:文件路径
transform:在构建数据集的时候对图片的一些操作,如将图片裁剪成统一尺寸,将图片转换为张量等,可以通过torchvision.transforms来实现。
注意:图片转换为张量时格式由原来的(H,W,C)变为了(C,H,W)
5.torchvision.utils模块
6.实例
import torch
from torchvision import transforms,utils
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
"""构建数据集"""
"""对数据预处理操作"""
my_trans=transforms.Compose([transforms.CenterCrop(256),transforms.RandomHorizontalFlip(p=0.5),
transforms.ToTensor()])
dataset=ImageFolder("F:/2020/pytorch/inkCode-pytorch_tutorial-master/pytorch_tutorial/028_034:图像识别核心模块实战解读/卷积网络实战/flower_data/valid"
,transform=my_trans)
data_loader=DataLoader(dataset,batch_size=20,shuffle=True)
"""查看数据集"""
for index,loader in enumerate(data_loader):
data,label=loader
if index==0:
print("index:",index)
"""将多张图片拼接:"""
grid=utils.make_grid(data)
"""将多张照片拼成一个网格"""
plt.imshow(grid.numpy().transpose(1,2,0))
plt.show()
"""将张量保存为图片"""
utils.save_image(grid,"test01.jpg")
运行结果: