Pytorch与深度学习自查手册2-数据加载和预处理
数据加载
DataSet类
自定义一个继承 Dataset类的类 ,需要重写以下三个函数:
__init__
:传入数据,或者像下面一样直接在函数里加载数据;__len__
:返回这个数据集一共有多少个item;__getitem__
:返回一条训练数据,并将其转换成tensor。- 通常还会在其中增加一个
collate_fn函数
,用于DataLoader
,使用这个参数可以自己操作每个batch的数据,比如说在自然语言处理的命名实体识别任务中,在该函数中对每个batch中的样本都padding
到同一长度等。
import torch
from torch.utils.data import Dataset
class Mydata(Dataset):
def __init__(self,path):
#加载数据
a = np.load("a.npy",allow_pickle=True)
b = np.load("b.npy",allow_pickle=True)
d = np.load("d.npy",allow_pickle=True)
c = np.load("c.npy")
self.x = list(zip(a,b,d,c))
self.y = ...
def __getitem__(self, idx):
assert idx < len(self.x)
return self.x[idx],self.y[idx]
def __len__(self):
return len(self.x)
def collate_fn(self,batch):
#……
pass
collate_fn:如何取样本
一般的,默认的collate_fn
函数是要求一个batch中的图片都具有相同size(因为要做stack操作),当一个batch中的图片大小都不同时,可以使用自定义的collate_fn
函数,则一个batch中的图片不再被stack操作,可以全部存储在一个list中,当然还有对应的label
# a simple custom collate function, just to show the idea
def my_collate(batch):
data = [item[0] for item in batch]
target = [item[1] for item in batch]
target = torch.LongTensor(target)
return [data, target]
DataLoader类
DataLoader包括三个参数:
dataset
:传入的数据;shuffle
= True:是否打乱数据;collate_fn
函数:使用这个参数可以自己操作每个batch的数据。drop_last
:告诉如何处理划分batch后剩下的最后不足一个batch的样本集合,True就抛弃,否则保留。
from torch.utils.data import DataLoader
dataset = Mydata()
#构建DataLoader
dataloader = DataLoader(dataset, batch_size = 2, shuffle=True,collate_fn = dataset.collate_fn)
从DataLoader中取样本
#从dataloader中逐一取样本
train_features, train_labels = next(iter(train_dataloader))
#循环取样本
for X, y in dataloader:
print("Shape of X [N, C, H, W]: ", X.shape)
print("Shape of y: ", y.shape, y.dtype)
break
数据预处理
自定义collate_fn函数传入
自定义collate_fn
函数传入DataLoader。
transforms:对图片进行变换
PyTorch 学习笔记(三):transforms的二十二个方法
transforms.ToTensor()转化数据类型
from torchvision import transforms, utils
from PIL import Image
img=Image.image(img_path)
tensor_trans=transforms.ToTensor()
img_tensor=tensor_trans(img)
合并数据处理过程transforms.Compose()
trans_compose=transforms.Compose([transforms.Resize(),transforms.ToTensor()])
trans_compose(img)
正则化
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
transforms.Resize(image_size)
随机裁剪/中心裁剪
transforms. RandomCrop((512,1000))
transforms.CenterCrop(image_size)
整数转one-hot
num_class=10
target_transform = Lambda(lambda y: torch.zeros(
num_class, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))