深度学习初学者,如何下载常用公开数据集并使用呢?
1.前言
刚开始进行深度学习的时候,难免要用到一些公开数据集,现在闲来无事,记录一下如何快速下载一些经典数据集。通过官方文档学习,是一些大牛们挂在嘴边经常推荐的方法,那么我们本篇博客就从官方文档开始学习。
因为我是做CV方向的,所以用TorchVision这个库举例。来自官网:This library is part of the [PyTorch](http://pytorch.org/) project. PyTorch is an open source machine learning framework.
The [torchvision] package consists of popular datasets, model architectures, and common image transformations for computer vision.
包括很多流行数据集,如我们常见的CIFAR,COCO和MINST,大家应该都不陌生。一会儿会以CIFAR举例,记录一下我的过程。
2.官方文档怎样看
-
首先我们看一下
CIFAR
这个类的文档:参数:
root:表示将下载的数据集放在哪个目录
root (string): Root directory of dataset where directory ``cifar-10-batches-py`` exists or will be saved to if download is set to True.
train:是否为训练数据集
train (bool, optional): If True, creates dataset from training set, otherwise creates from test set.
transform:一个将图像进行预处理、返回transform的函数
A function/transform that takes in an PIL image and returns a transformed version.
download:是否下载数据集,
download (bool, optional):If true, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
3.动手写代码
-
示例代码
# 导入torchvision包 import torchvision # 对原始图像进行数据处理的函数 dataset_transform = torchvision.transforms.Compose([ torchvision.transforms.ToTensor() ]) # 生成训练数据集和测试数据集 # 训练数据集 存放在根目录的dataset文件夹下,作为训练数据集,并下载 train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True) # 测试数据集 存放在根目录的dataset文件夹下,不作为训练数据集,并下载 test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True) print(test_set[0])
-
然后我们右键运行,进行下载
可以看到数据集已经开始下载了,但是因为是从toronto.edu下载,速度很慢。教你一个更快的方法:我们终止运行,复制这个链接,用迅雷下载,很快就好了。然后将下载好的
.gz
文件进行解压,放到我们创建的dataset
目录下: -
重新run,就可以正常使用数据集了。
4.如何可视化
我用tensorboard
进行了可视化,大家感兴趣可以研究一下tensorboard这个库。
import torchvision
from torch.utils.tensorboard import SummaryWriter
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
dataset_transform = torchvision.transforms.Compose([
torchvision.transforms.ToTensor()
])
# 返回类型
train_set = torchvision.datasets.CIFAR10(root="./dataset", train=True, transform=dataset_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset", train=False, transform=dataset_transform, download=True)
print(test_set[0])
writer = SummaryWriter("p10")
for i in range(10):
img, target = test_set[i]
writer.add_image("test_set", img, i)
writer.close()
在浏览器上就可以看到图像啦:

遇到问题:ssl.SSLCertVerificationError: [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1131)
如果在下载中遇到同样的问题,需要导入ssl:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
说在最后的话:编写实属不易,若喜欢或者对你有帮助记得点赞 + 关注或者收藏哦~