Pytorch与深度学习自查手册2-数据加载和预处理

Pytorch与深度学习自查手册2-数据加载和预处理

数据加载

DataSet类

自定义一个继承 Dataset类的类 ,需要重写以下三个函数:

  1. __init__:传入数据,或者像下面一样直接在函数里加载数据;
  2. __len__:返回这个数据集一共有多少个item;
  3. __getitem__:返回一条训练数据,并将其转换成tensor。
  4. 通常还会在其中增加一个collate_fn函数,用于DataLoader,使用这个参数可以自己操作每个batch的数据,比如说在自然语言处理的命名实体识别任务中,在该函数中对每个batch中的样本都padding到同一长度等。
import torch
from torch.utils.data import Dataset
class Mydata(Dataset):
    def __init__(self,path):
        #加载数据
        a = np.load("a.npy",allow_pickle=True)
        b = np.load("b.npy",allow_pickle=True)
        d = np.load("d.npy",allow_pickle=True)
        c = np.load("c.npy")
        self.x = list(zip(a,b,d,c))
        self.y = ...
    def __getitem__(self, idx):
        assert idx < len(self.x)
        return self.x[idx],self.y[idx]
    def __len__(self):
        return len(self.x)
    def collate_fn(self,batch):
    	#……
        pass
collate_fn:如何取样本

一般的,默认的collate_fn函数是要求一个batch中的图片都具有相同size(因为要做stack操作),当一个batch中的图片大小都不同时,可以使用自定义的collate_fn函数,则一个batch中的图片不再被stack操作,可以全部存储在一个list中,当然还有对应的label

# a simple custom collate function, just to show the idea 
def my_collate(batch):    
    data = [item[0] for item in batch]    
    target = [item[1] for item in batch]    
    target = torch.LongTensor(target)    
    return [data, target]

DataLoader类

DataLoader包括三个参数:

  1. dataset:传入的数据;
  2. shuffle = True:是否打乱数据;
  3. collate_fn函数:使用这个参数可以自己操作每个batch的数据。
  4. drop_last:告诉如何处理划分batch后剩下的最后不足一个batch的样本集合,True就抛弃,否则保留。
from torch.utils.data import DataLoader
dataset = Mydata()
#构建DataLoader
dataloader = DataLoader(dataset, batch_size = 2, shuffle=True,collate_fn = dataset.collate_fn)

从DataLoader中取样本

#从dataloader中逐一取样本
train_features, train_labels = next(iter(train_dataloader)) 
#循环取样本
for X, y in dataloader: 
    print("Shape of X [N, C, H, W]: ", X.shape)
    print("Shape of y: ", y.shape, y.dtype)
    break

数据预处理

自定义collate_fn函数传入

自定义collate_fn函数传入DataLoader。

transforms:对图片进行变换

PyTorch 学习笔记(三):transforms的二十二个方法

transforms.ToTensor()转化数据类型
from torchvision import transforms, utils
from PIL import Image
img=Image.image(img_path)
tensor_trans=transforms.ToTensor()
img_tensor=tensor_trans(img)
合并数据处理过程transforms.Compose()
trans_compose=transforms.Compose([transforms.Resize(),transforms.ToTensor()])
trans_compose(img)
正则化

transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

transforms.Resize(image_size)

随机裁剪/中心裁剪

transforms. RandomCrop((512,1000))

transforms.CenterCrop(image_size)

整数转one-hot
num_class=10
target_transform = Lambda(lambda y: torch.zeros(
    num_class, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

参考资料

图神经网络的下游任务3-图分类 | 冬于的博客 (ifwind.github.io)

PyTorch 学习笔记(三):transforms的二十二个方法

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值