(第二篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

前言:前面的系列文章之第一篇已经基本上说明了DataSet类和DataLoader类的用法,但是鉴于DataLoader类中有一个参数collate_fn使用起来比较复杂,所以本次的第二篇文章还专门说一下这个函数的功能。第一篇文章请参考:

(第一篇)pytorch数据预处理三剑客之——Dataset,DataLoader,Transform

collate_fn,中单词collate的含义是:核对,校勘,对照,整理。顾名思义,这就是一个对每一组样本数据进行一遍“核对和重新整理”,现在可能更好理解一些。

一、本次案例

本次为了更加方便的演示整个过程,假设有20组训练样本,输入的样本x是【1,2,3,4,,,,18,19,20】,

输出的标签y是 【100,200,300,,,,1800,1900,2000】

现在我不是用collate_fn参数,我要将数据随机打乱,并且batch_size等于3,简单的实现如下:

import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader

x=range(1,21,1) 
y=range(100,2100,100)

class XYDataSet(Dataset):
    def __init__(self,x,y):
        self.x_list=x
        self.y_list=y
        assert len(self.x_list)==len(self.y_list)

    def __len__(self):
        return len(self.x_list)

    def __getitem__(self,index):
        x_one=self.x_list[index]
        y_one=self.y_list[index]
        return (x_one,y_one)

# 第一步:构造dataset
dataset=XYDataSet(x,y)
# 第二步:构造dataloader
dataloader=DataLoader(dataset,batch_size=3,shuffle=True)

# 第三步:对dataloader进行迭代
for epoch in range(2): # 只查看两个epoch
    for x_train,y_train in dataloader:
        print(x_train)
        print(y_train)
        print("-----------------------------------")
'''
tensor([ 3, 12, 16])
tensor([ 300, 1200, 1600])
-----------------------------------
tensor([13, 15,  9])
tensor([1300, 1500,  900])
-----------------------------------
tensor([ 8, 19,  7])
tensor([ 800, 1900,  700])
-----------------------------------
tensor([18,  1, 14])
tensor([1800,  100, 1400])
-----------------------------------
tensor([ 2,  5, 20])
tensor([ 200,  500, 2000])
-----------------------------------
tensor([11, 17,  6])
tensor([1100, 1700,  600])
-----------------------------------
tensor([10,  4])
tensor([1000,  400])
-----------------------------------    # 这里是第一个epoch结束了,会进行一次混洗
tensor([14,  2,  3])
tensor([1400,  200,  300])
-----------------------------------
tensor([11,  1, 15])
tensor([1100,  100, 1500])
-----------------------------------
tensor([10,  6, 20])
tensor([1000,  600, 2000])
-----------------------------------
tensor([13,  5,  8])
tensor([1300,  500,  800])
-----------------------------------
tensor([18, 19, 12])
tensor([1800, 1900, 1200])
-----------------------------------
tensor([ 4, 16,  7])
tensor([ 400, 1600,  700])
-----------------------------------
tensor([ 9, 17])
tensor([ 900, 1700])
-----------------------------------
'''

但是现在有一个问题,我希望对于原来的样本数据重新你处理一下,将每一组样本的x加上0.5,将每一组样本的y加上50,然后重新组成样本,我当然可以这么做,即在定义DataSet的__getitem__里面去实现,只需要简单的更改__getitem__即可,如下:

def __getitem__(self,index):
        x_one=self.x_list[index]+0.5  # 每一个x加上0.5
        y_one=self.y_list[index]+50   # 每一个y加上50
        return x_one,y_one

二、通过自定义collate_fn函数来实现

这里整个DataSet的实现完全不变,定义的函数如下:

import numpy as np
import torch
from torch.utils.data import Dataset,DataLoader
from torch.utils.data.dataloader import default_collate  # 导入这个函数,这个函数其实就是pytorch默认给这个collate_fn的默认实现

def collate_fn(batch):
    """
    batch :是一个列表,列表的长度是 batch_size
           列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
    """
    new_batch=[]
    for index in range(len(batch)):
        x_=batch[index][0]+0.5  # 每一个样本x加上0.5
        y_=batch[index][1]+50   # 没一个样本y加上50

        new_batch.append((x_,y_))  # 将改变之后的x,y重新组成一个batch

    return default_collate(new_batch)
# 第一步:构造dataset
dataset=XYDataSet(x,y)
# 第二步:构造dataloader,这里需要传递自定义的collate_fn函数
dataloader=DataLoader(dataset,batch_size=3,shuffle=True,collate_fn=collate_fn)

# 第三步:对dataloader进行迭代
for epoch in range(2): # 只查看两个epoch
    for x_train,y_train in dataloader:
        print(x_train)
        print(y_train)
        print("-----------------------------------")
'''
tensor([18.5000,  1.5000, 20.5000], dtype=torch.float64)
tensor([1850,  150, 2050])
-----------------------------------
tensor([19.5000, 14.5000,  4.5000], dtype=torch.float64)
tensor([1950, 1450,  450])
-----------------------------------
tensor([6.5000, 2.5000, 5.5000], dtype=torch.float64)
tensor([650, 250, 550])
-----------------------------------
tensor([17.5000,  9.5000, 15.5000], dtype=torch.float64)
tensor([1750,  950, 1550])
-----------------------------------
tensor([12.5000,  8.5000,  7.5000], dtype=torch.float64)
tensor([1250,  850,  750])
-----------------------------------
tensor([ 3.5000, 13.5000, 16.5000], dtype=torch.float64)
tensor([ 350, 1350, 1650])
-----------------------------------
tensor([11.5000, 10.5000], dtype=torch.float64)
tensor([1150, 1050])
-----------------------------------
tensor([ 2.5000,  7.5000, 18.5000], dtype=torch.float64)
tensor([ 250,  750, 1850])
-----------------------------------
tensor([14.5000, 13.5000,  9.5000], dtype=torch.float64)
tensor([1450, 1350,  950])
-----------------------------------
tensor([10.5000, 12.5000, 17.5000], dtype=torch.float64)
tensor([1050, 1250, 1750])
-----------------------------------
tensor([20.5000, 15.5000,  8.5000], dtype=torch.float64)
tensor([2050, 1550,  850])
-----------------------------------
tensor([ 6.5000, 11.5000, 19.5000], dtype=torch.float64)
tensor([ 650, 1150, 1950])
-----------------------------------
tensor([5.5000, 1.5000, 4.5000], dtype=torch.float64)
tensor([550, 150, 450])
-----------------------------------
tensor([16.5000,  3.5000], dtype=torch.float64)
tensor([1650,  350])
-----------------------------------
'''

三、collate_fn函数的一般定义格式

从上面的例子中可以更加清楚的理解“collate的校对、整理”的含义,这个函数自定义实现的时候有一个大致的模板:

from torch.utils.data.dataloader import default_collate  # 导入这个函数

def collate_fn(batch):
    """
    params:
        batch :是一个列表,列表的长度是 batch_size
               列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
               大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)]
    returns:
        整理之后的新的batch
    """
     
    # 这一部分是对 batch 进行重新 “校对、整理”的代码

    return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。

 

  • 23
    点赞
  • 52
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
PyTorch中,数据预处理通常涉及以下几个步骤: 1. 加载数据集:使用PyTorch的数据加载器(如`torchvision.datasets`)加载数据集。可以是常见的图像数据集(如MNIST、CIFAR10)或自定义数据集。 2. 转换数据:使用`torchvision.transforms`模块中的转换函数对数据进行预处理。常见的转换包括缩放、裁剪、旋转、归一化等。可以根据需求组合多个转换操作。 3. 创建数据加载器:将转换后的数据集传递给`torch.utils.data.DataLoader`来创建一个数据加载器。数据加载器可以指定批处理大小、并发加载等参数。 下面是一个简单的示例,演示如何使用PyTorch进行数据预处理: ```python import torch import torchvision import torchvision.transforms as transforms # 1. 加载数据集 train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True) # 2. 转换数据 transform = transforms.Compose([ transforms.ToTensor(), # 将图像转换为Tensor transforms.Normalize((0.5,), (0.5,)) # 归一化到[-1, 1]范围 ]) train_dataset = train_dataset.transform(transform) # 3. 创建数据加载器 train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True) ``` 在这个示例中,我们加载了MNIST数据集,并将图像转换为Tensor,并进行了归一化处理。然后使用`DataLoader`创建了一个批处理大小为64的数据加载器,同时打乱了数据的顺序。 这只是一个简单的例子,根据具体需求,你可能需要进行更复杂的数据预处理操作。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值