pytorch 官方数据集的使用
1.官方数据集的查看
- 登录pytorch官网
- 找到Docs–>torchvision
- 点击左上角Search Docs 上面的小箭头切换版本
- 改为v0.9.0,然后就可以看到官方数据集了(如下)
2.头文件:
import torchvision
3.主要使用的函数:
test_dataset2=torchvision.datasets.CIFAR10(root="./test_dataset2", train=False, transform=data_compose, download=True)
#函数:
#torchvision.datasets.官方数据集名(root, train, download)
#参数:
#root 为数据集本地存放的路径
#train 该数据集是否是为训练集
#download 是否将数据集下载到本地
#(若已下载到对应目录,函数就自动不再重复下载了)
#transform 对数据集中的每一个图片进行相应的处理(可省略)
DataLoader
- 作用:
将指定的目标数据以及对应的label进行打包,并发送给后面的神经网络(为后面网络提供不同的数据形式),简单点来说,就是将指定的dataset按照指定的方式进行打包处理。 - 使用:
引入头文件
from torch.utils.data import DataLoader
主要使用的函数:
test_dataloader=DataLoader(dataset=test_dataset2, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
#参数:
#dataset 要训练的数据集
#batch_size 每一次从数据集中取几份数据
#shuffle 每轮取完数据后是否打乱数据顺序,以便下次取
# drop_last 数据取到最后不足batch_size份,是否丢弃
函数作用:
将数据集中的数据按指定格式打包即每一次从数据集中取几份数据 ,一个包中,所有图片放在一起,所有对应的target放在一起。
函数返回值为所有包的一个集合,若要调用每一个包,可用for循环来依次调用。
3.实例代码及结果:
import torchvision
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
data_compose=torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
test_dataset2=torchvision.datasets.CIFAR10(root="./test_dataset2", train=False, transform=data_compose, download=True)
#数据集的引入
test_dataloader=DataLoader(dataset=test_dataset2, batch_size=64, shuffle=True, num_workers=0, drop_last=False)
#对数据进行打包(每64份数据为一个包)
writer= SummaryWriter(r"E:\python\pythonProject_pyTorch\dataLoader")
step=0
for data in test_dataloader:
imgs, targets = data
# print(imgs.shape) #每次循环有batch_size张图片(图片包)
# print(targets) #图片依次对应的target(target包)
writer.add_images("test_data", imgs, step)
step=step+1
writer.close()
代码中用的是pytorch官网数据集 CIFAR10