PyTorch 基础 :数据的加载和预处理
PyTorch通过torch.utils.data对一般常用的数据加载进行了封装,可以很容易地实现多线程数据预读和批量加载。 并且torchvision已经预先实现了常用图像数据集,包括前面使用过的CIFAR-10,ImageNet、COCO、MNIST、LSUN等数据集,可通过torchvision.datasets方便的调用
import torch
torch.__version__
'1.2.0'
Dataset
Dataset是一个抽象类, 为了能够方便的读取,需要将要使用的数据包装为Dataset类。 自定义的Dataset需要继承它并且实现两个成员方法:
__getitem__() 该方法定义用索引(0 到 len(self))获取一条数据或一个样本
__len__() 该方法返回数据集的总长度
下面我们使用kaggle上的一个竞赛bluebook for bulldozers自定义一个数据集,为了方便介绍,我们使用里面的数据字典来做说明(因为条数少)
from torch.utils.data import Dataset
import pandas as pd
class BulldozerDateset(Dataset): #继承Dataset类
"""数据集演示"""
def __init__(self,csv_file):
self.df = pd.read_csv(csv_file)
def __len__(self):
#返回df的长度
return len(self.df)
def __getitem__(self,idx):
# 根据idx返回一行数据
return self.df.iloc[idx]['x1'] #读取下标为idx的一行数据。可以在后面加 .列名 的方式读取属于某一列的数据,否则就会读取这一整行
data = BulldozerDateset(r'data/datatest_1.csv')
# 实现了 __len__ 方法所以可以直接使用len()获取数据总数
len(data)
25
#用索引可以直接访问对应数据,对应 __getitem__方法:
for i in range(len(data)):
print(data[i])
# x1 , x2 = data[i]
# print("x1 = ",x1,"x2 = ",x2)
0.232991543
0.449915356
0.840922298
0.20727367
0.541869015
0.36092917399999996
0.668949803
0.15037799400000001
0.898436358
0.302533521
0.286306281
0.8299904779999999
0.69587861
0.848728081
0.527168802
0.231401747
0.5379995989999999
0.6874250009999999
0.5887086970000001
0.252984848
0.067723107
0.220923151
0.49844191299999996
0.886522632
0.5958730729999999
Dataloader
DataLoader为我们提供了对Dataset的读取操作,常用参数有:batch_size(每个batch的大小), shuffle(是否进行shuffle操作), num_workers(加载数据的时候使用几个子进程),下面做一个简单的操作