前言
本文主要总结一下使用pytorch
过程中几种对数据集的读取与使用,主要领域涉及CV
和NLP
CV
pytorch中的CV
数据集主要包含torchvision
提供的预先处理好的数据集,例如MNIST
,cifar10
和现实生活中的图像数据。这里分别用MNIST
和kaggle
猫狗分类数据为例。
MNIST
torchvision
预处理好的数据集都在tochvision.dataset
包中,要导入MNIST
,使用以下代码:
from torchvision import datasets
下面的代码展示了导入MNIST
按照训练集、测试集划分,并按照指定的batch_size
返回数据
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
def load_data(data_dir, batch_size):
train_dataset = datasets.MNIST(root=data_dir, # MNIST存储路径
train=True, # 是否是训练集
transform=transforms.ToTensor(), # 转换图像通道顺序,并标准化(将数据压缩到0~1)
download=True) # 如果目录下没有数据,则下载
test_dataset = datasets.MNIST(root=data_dir,
train=False,
transform=transforms.ToTensor())
# get data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
return train_loader, test_loader
train_loader, test_loader = load_data('./data', 64)
for data, target in train_loader:
print(data.size(),target.size())
break
执行后程序会自动将MNIST原数据集下载到指定的data_dir
下,返回一个dataset
对象,打印这个对象信息得到如下图:
接下来用DataLoader
去读取该数据,用于返回每次提供batch_size
的数据加载器。打印一个batch
的数据结果如下:
猫狗数据集
上面的数据集是pytorhc
为我们处理好的常用数据集,但现实情况下我们通常要对文件(.jpg)形式的数据进行读取与使用,kaggle
的猫狗分类数据集就是类似这种情况。
以猫狗分类数据中的
train
目录下数据