PyTorch 2.1.4 数据的加载和预处理

PyTorch 基础 :数据的加载和预处理

PyTorch通过torch.utils.data对一般常用的数据加载进行了封装,可以很容易地实现多线程数据预读和批量加载。
并且torchvision已经预先实现了常用图像数据集,包括前面使用过的CIFAR-10,ImageNet、COCO、MNIST、LSUN等数据集,可通过torchvision.datasets方便的调用

# 首先要引入相关的包
import torch
#打印一下版本
torch.__version__
'1.12.1'

Dataset
Dataset是一个抽象类,为了能够方便的读取,需要将要使用的数据包装为Dataset类。
自定义的Dataset需要继承它并且实现两个成员方法:

  • getitem() 该方法定义用索引(0 到 len(self))获取一条数据或一个样本

  • len() 该方法返回数据集的总长度

下面我们使用kaggle上的一个竞赛Blue Book for Bulldozers 预测推土机价格自定义一个数据集,该数据已经在和鲸社区上整理好了,为了方便介绍,我们使用里面的数据字典来做说明(因为条数少)

from torch.utils.data import Dataset
import pandas as pd
# 定义一个数据集
class BulldozerDataset(Dataset):
    '''数据集演示'''
    def __init__(self, csv_file):
        """实现初始化方法,在初始化的时候将数据读载入"""
        self.df=pd.read_csv(csv_file)
    def __len__(self):
        '''
        返回df的长度
        '''
        return len(self.df)
    def __getitem__(self, idx):
        '''
        根据 idx 返回一行数据
        '''
        return self.df.iloc[idx].SalePrice

至此,我们的数据集已经定义完成了,我们可以实例话一个对象访问他

ds_demo = BulldozerDataset(r'C:\Users\ADMIN\Desktop\Pytorch学习资料\median_benchmark.csv')
#实现了 __len__ 方法所以可以直接使用len获取数据总数
len(ds_demo)
11573
#用索引可以直接访问对应的数据,对应 __getitem__ 方法
ds_demo[0]
24000.0

Dataloader

DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size(每个batch的大小)、 shuffle(是否进行shuffle操作)、 num_workers(加载数据的时候使用几个子进程)。下面做一个简单的操作

dl = torch.utils.data.DataLoader(ds_demo, batch_size=10, shuffle=True, num_workers=0)

DataLoader返回的是一个可迭代对象,我们可以使用迭代器分次获取数据

idata = iter(dl)
print(next(idata))
tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
        24000.], dtype=torch.float64)

常见的用法是使用for循环对其进行遍历

for i, data in enumerate(dl):
    print(i, data)
    break
0 tensor([24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000., 24000.,
        24000.], dtype=torch.float64)

我们已经可以通过dataset定义数据集,并使用Datalorder载入和遍历数据集,除了这些以外,PyTorch还提供能torcvision的计算机视觉扩展包,里面封装了

torchvision 包

torchvision 是PyTorch中专门用来处理图像的库,PyTorch官网的安装教程中最后的pip install torchvision 就是安装这个包。

torchvision.datasets

torchvision.datasets 可以理解为PyTorch团队自定义的dataset,这些dataset帮我们提前处理好了很多的图片数据集,我们拿来就可以直接使用:

MNIST
COCO
Captions
Detection
LSUN
ImageFolder
Imagenet-12
CIFAR
STL10
SVHN
PhotoTour

我们可以直接使用,示例如下:

import torchvision.datasets as datasets
train_set = datasets.MNIST(root='./data',# 表示 MNIST 数据的加载的目录
                          train=True, # 表示是否加载数据库的训练集,false的时候加载测试集
                          download=True,# 表示是否自动下载 MNIST 数据集
                          transform=None) # 表示是否需要对数据进行预处理,none为不进行预处理
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 502: Bad Gateway

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz



  0%|          | 0/9912422 [00:00<?, ?it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz



  0%|          | 0/28881 [00:00<?, ?it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 502: Bad Gateway

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz



  0%|          | 0/1648877 [00:00<?, ?it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz



  0%|          | 0/4542 [00:00<?, ?it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

torchvision.models

torchvision不仅提供了常用图片数据集,还提供了训练好的模型,可以加载之后,直接使用,或者在进行迁移学习
torchvision.models模块的 子模块中包含以下模型结构。

AlexNet
VGG
ResNet
SqueezeNet
DenseNetm
#我们直接可以使用训练好的模型,当然这个与datasets相同,都是需要从服务器下载的
import torchvision.models as models
vgg16 = models.vgg16(pretrained=True)

torchvision.transforms

transforms 模块提供了一般的图像转换操作类,用作数据处理和数据增强

import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  #先四周填充0,在把图像随机裁剪成32*32
    transforms.RandomHorizontalFlip(),  #图像一半的概率翻转,一半的概率不翻转
    transforms.RandomRotation((-45,45)), #随机旋转
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.229, 0.224, 0.225)), #R,G,B每层的归一化用到的均值和方差
])

关于(0.485, 0.456, 0.406), (0.2023, 0.1994, 0.2010)

官方的这个帖子有详细的说明:
https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457/21
这些都是根据ImageNet训练的归一化参数,可以直接使用,我们认为这个是固定值就可以


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值