1.进入PyTorch官网
点击torchvision
2.在左上角版本选择切换到0.9.0版本即可看到torchvision的数据集目录
3.点击想要下载的数据集名称即可看到该类需要写入的参数有哪些
4.在pycharm中通过代码实现自动下载:
import torchvision
#自动下载数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)
#print(test_set[0]) #测试集第一个数据
#print(test_set.classes) #测试集的类型
#img,target = test_set[0]
#print(img)
#print(target)
#print(test_set.classes[target]) #输出测试集类型名称
#img.show() #图片查看
其中root代表你的数据集下载保存的位置,train为True则表示这是训练集,False代表是测试集,download为True表示自动下载,False表示非自动下载。
注:点击运行后如果下载速度过慢,可选择手动下载,通过迅雷加速下载,下载后手动导入root设定的目录下,即可使用
5.使用tensorboard展示图片:需要先把所有PIL格式的图片通过transforms的Compose方法转换为Tensor格式,再去进行展示
import torchvision
from torch.utils.tensorboard import SummaryWriter
dataset_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()]) #对数据集中每一张图片都转换为tensor类型
#自动下载数据集
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=dataset_transform,download=True)
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=dataset_transform,download=True)
#用tensorboard进行显示
writer = SummaryWriter("dataexample")
for i in range(10):
img,target = test_set[i]
writer.add_image("test_set",img,i)
writer.close()
在这里我犯了一个错误:
、
原因是我在编写代码时忘记了数据集中包括了图片和标签,代码直接编写为下面这样
#用tensorboard进行显示
writer = SummaryWriter("dataexample")
for i in range(10):
writer.add_image("test_set",test_set[i],i)
writer.close()
改为下面这样才正确,只需要img添加图片就可以了。
#用tensorboard进行显示
writer = SummaryWriter("dataexample")
for i in range(10):
img,target = test_set[i]
writer.add_image("test_set",img,i)
writer.close()
6.最终运行结果: