Pytorch基础:数据加载和预处理
Pytorch通过torch.utils.data对数据实现封装,可以容易的实现多线程数据预读和批量加载
import torch
torch.__version__
'1.1.0'
Dataset
Dataset是一个抽象类,为方便读取,需要将使用的数据包装为Dataset类。自定义Dataset需要继承它并实现他的两个方法:
- getitem() 该方法定义用索引(0到self.len)获取一条数据或一个样本
- len() 该方法返回数据总长度
from torch.utils.data import Dataset
import numpy as np
# 定义一个数据类
class Diabetes(Dataset):
def __init__(self):
super(Diabetes, self).__init__()
data = np.loadtxt('.//data//diabetes.csv.gz',
delimiter=',',
dtype=np.float32)
self.len = data.shape[0]
self.x_data = torch.from_numpy(data[:, 0:-1])
self.y_data = torch.from_numpy(data[:, [-1]])
# 根据index返回一行数据
def __getitem__(self, index):
return self.x_data[index], self.y_data[index]
def __len__(self):
# 返回data长度
return self.len
len 方法可以直接使用len获取数据总数
diabetes = Diabetes()
len(diabetes)
759
DataLoader
DataLoader提供了对Dataset的读取操作,常用的参数:batch_size(每个批次大小),shuffle(是否进行shuffle操作),num_workers(加载数据时使用几个子进程)
d = torch.utils.data.DataLoader(diabetes,
batch_size=10,
shuffle=True,
num_workers=0)
DataLoader返回一个可迭代对象,可使用迭代器分批次获取
itdata = iter(d)
next(itdata)
[tensor([[ 0.7647, 0.3668, 0.1475, -0.3535, -0.7400, 0.1058, -0.9360, -0.2667],
[-0.8824, 0.0050, 0.0820, -0.6970, -0.8676, -0.2966, -0.4979, -0.8333],
[-0.5294, 0.3668, 0.1475, 0.0000, 0.0000, -0.0700, -0.0572, -0.9667],
[ 0.0000, 0.4171, 0.0000, 0.0000, 0.0000, 0.2638, -0.8915, -0.7333],
[-0.7647, 0.0854, 0.0164, -0.3535, -0.8676, -0.2489, -0.9573, 0.0000],
[-0.5294, 0.1256, 0.2787, -0.1919, 0.0000, 0.1744, -0.8651, -0.4333],
[-0.7647, -0.1859, -0.0164, -0.5556, 0.0000, -0.1744, -0.8190, -0.8667],
[-0.7647, 0.1859, 0.3115, 0.0000, 0.0000, 0.2787, -0.4748, 0.0000],
[-0.8824, 0.1256, 0.3115, -0.0909, -0.6879, 0.0373, -0.8813, -0.9000],
[ 0.0000, 0.4673, 0.3443, 0.0000, 0.0000, 0.2072, 0.4543, -0.2333]]),
tensor([[0.],
[1.],
[0.],
[0.],
[1.],
[1.],
[1.],
[0.],
[1.],
[1.]])]
# 常见用法是使用for循环遍历
for i, data in enumerate(d):
print(i, data)
break
0 [tensor([[ 0.0000, 0.1859, 0.3770, -0.0505, -0.4563, 0.3651, -0.5961, -0.6667],
[ 0.0588, -0.1055, 0.0164, 0.0000, 0.0000, -0.3294, -0.9453, -0.6000],
[ 0.0000, 0.1759, 0.0820, -0.3737, -0.5556, -0.0820, -0.6456, -0.9667],
[ 0.1765, 0.6884, 0.2131, 0.0000, 0.0000, 0.1326, -0.6080, -0.5667],
[-0.7647, 0.4673, 0.0000, 0.0000, 0.0000, -0.1803, -0.8617, -0.7667],
[-0.1765, 0.1457, 0.2459, -0.6566, -0.7400, -0.2906, -0.6687, -0.6667],
[-0.7647, 0.1256, 0.0820, -0.5556, 0.0000, -0.2548, -0.8044, -0.9000],
[-0.8824, 0.3367, 0.6721, -0.4343, -0.6690, -0.0224, -0.8668, -0.2000],
[-0.8824, 0.6784, 0.2131, -0.6566, -0.6596, -0.3025, -0.6849, -0.6000],
[-0.8824, 0.1658, 0.2787, -0.4141, -0.5745, 0.0760, -0.6430, -0.8667]]), tensor([[0.],
[1.],
[1.],
[0.],
[0.],
[1.],
[1.],
[0.],
[0.],
[1.]])]
torchvision
torchvision是Pytorch中用来处理图像的库
torchvision.datasets 为Pytorch官方定义的dataset:可直接使用MNIST、COCO、Detetion、LSUN、CIFAR10等
from torchvision import datasets, transforms
trainset = datasets.MNIST(
root='.//data//', # 加载MNIST数据的目录
train=True, # 标识加载数据集,为false时为测试集
download=True, # 是否自动下载数据
transform=True) # 是否需要对数据进行预处理, None时不进行预处理
torchvision.models
torchvision还提供了训练好的模型,可以在进行迁移学习torchvision.models模块的子模块中包含以下结构:
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
from torchvision import models
resnet18 = models.resnet18(pretrained=True)
Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to C:\Users\Zephyrus/.cache\torch\checkpoints\resnet18-5c106cde.pth
---------------------------------------------------------------------------
torchvision.transforms
transforms模块提供了一般的图像转换操作类,用于数据处理和数据增强
from torchvision import transforms
transform = transforms.Compose([
transforms.RandomCrop(32, padding=4), # 先四周填充0,把图像随机裁剪成32x32
transforms.RandomHorizontalFlip(), # 把图像一般概率翻转,一半的概率不翻转
transforms.RandomRotation((-45, 45)), # 随机旋转
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.229, 0.224, 0.225)) # RGB每层的归一化用到的均值和方差
])
关于(0.4914, 0.4822, 0.4465),(0.229, 0.224, 0.225)详情说明,这些是根据ImageNet训练的归一化参数,可以直接使用,可认为为固定值