一、torchvision中数据集的使用
torchvision中的datasets类提供了一些常用的数据集,下面以CIFAR10数据集作为例子介绍一下它的用法root:
1.函数参数
root:数据集的根目录,数据集将会保存在该目录下
tran:默认为True,True为训练好的数据集,False为测试数据集
transfrom:接受PIL图片并返回转换版本的函数
target_transform:接受目标并返回转换版本的函数
download:bool默认为False,True则会下载该数据集
2.数据集
print(test_set[0]) # 打印数据集中的第一个图片的信息是一个元组(img, target)
print(test_set.classes)
classes是一个列表,包括这个数据集中包含的图片信息
img, target = test_set[0] # target是图片的classes信息
print(img)
print(target)
print(test_set.classes[target])
img.show()
运行结果
打开图片确实应该是一只小猫
3.和transfrom的联合使用
writer = SummaryWriter("log2")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
二、dataloader的使用
一中对datasets中的数据集的使用相当于创建了一副纸牌,而dataloader就用于发牌以及使用这副牌
DataLoader的常用参数:
dataset:处理的数据集
batch_size:每次批量处理的数量
shuffle:是否打乱顺序
num_workers:加载数据集的线程数目
drop_last:当样本数量不能被batch_size整除时,是否舍去最后一批数据
仍然以CIFAR10数据集作为例子
import torchvision
# 准备的测试数据集
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
test_data = torchvision.datasets.CIFAR10("./dataset1", train=False, transform=torchvision.transforms.ToTensor())
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
# 测试数据集中第一张图片及target
img, target = test_data[0]
print(img.shape)
print(target)
writer = SummaryWriter("dataloader")
for epoch in range(2):
step = 0
for data in test_loader:
imgs, targets = data
# print(imgs.shape)
# print(targets)
writer.add_images("Epoch: {}".format(epoch), imgs, step)
step = step + 1
writer.close()