Pytorch框架学习记录4——数据集的使用(torchvision.dataset)
1. 数据集
在pytorch官网中我们可以看到pytorch自身所配有的数据集的情况,以及该数据集的类型、使用方法等。在这里,我们选择数据集较小的CIFAR10作为我们的示例数据集。
该数据集的调用和使用使用代码如下:
torchvision.datasets.CIFAR10(root: str, train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
参数说明:
- root ( string ) – 数据集的根目录,
cifar-10-batches-py
如果下载设置为 True,则该目录存在或将保存到该目录。 - train ( bool , optional ) – 如果为真,则从训练集创建数据集,否则从测试集创建。
- transform ( callable , optional ) – 一个函数/转换,它接受 PIL 图像并返回转换后的版本。例如,
transforms.RandomCrop
- target_transform ( callable , optional ) – 接收目标并对其进行转换的函数/转换。
- download ( bool , optional ) – 如果为 true,则从 Internet 下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。
2. 使用实例
下载CIFAR10数据集后,将其类型转换为tensor类型,并在tensorboard中进行展示。
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
dataset_transform = transforms.Compose([
transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=dataset_transform, download=True)
writer = SummaryWriter('logs')
for i in range(10):
img, label = train_set[i]
writer.add_image('train10', img, i)
writer.close()
此外,还可以直接通过链接使用浏览器下载,下载完毕后,在当前目录下也命名一个dataset文件夹并放入,上述代码不做任何改变,会自动将手动下载的数据集进行解压和修正。