Pytorch教程中导入FashionMNIST数据集无法找到
本地导入
在https://pytorch.org/tutorials/中学习pytorch课程时
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root="data",
train=False,
download=True,
transform=ToTensor()
)
在代码中,root=“data” 表示数据集将被下载到名为 “data” 的文件夹中。如果该文件夹不存在,将会自动创建。您可以根据需要更改存储路径。
train=True 表示加载训练集数据,而 train=False 表示加载测试集数据。
download=True 表示如果数据集文件不存在,则自动下载 FashionMNIST 数据集。如果您之前已经下载过数据集并将其保存在指定的 root 文件夹中,可以将该参数设置为 False,以避免重新下载数据集。
transform=ToTensor() 表示对加载的图像进行转换,将其转换为张量形式。ToTensor() 转换将图像数据从 PIL.Image 对象转换为 torch.Tensor 对象,并且将像素值归一化到范围 [0, 1]。
在这一部分中,我是通过https://github.com/zalandoresearch/下载FashionMNIST数据集,再进行本地导入,在进行本地导入时,应该修改为
`training_data = datasets.FashionMNIST(
root="D:\data",#你自己的文件夹
train=True,
download=False,#指不进行网络下载
transform=ToTensor()
)#test_data同样更改
但总会提示Dataset not found. You can use download=True to download it
。
后来终于发现问题,你需要将
这些gz后缀的文件进行解压,
并且将这一文件夹名字改为“raw”,将raw文件夹外的文件夹名字改为“FashionMNIST”,以我本人电脑为例,t10k-images-idx3-ubyte.gz文件的路径应为‘D:…\pytorch\FashionMNIST\raw\t10k-images-idx3-ubyte.gz’。
所以我们的代码应该写为
training_data = datasets.FashionMNIST(
root=r'D:\...\pytorch',
train=True,
download=False,
transform=ToTensor()
)
这样才可以导入FashionMNIST数据集。
直接网络下载
直接下载就是写为“download=True”,root里的路径可以设置自己喜欢的文件夹,最后会在选择的文件夹中自动新建“FashionMNIST”文件夹。
training_data = datasets.FashionMNIST(
root="data",
train=True,
download=True,
transform=ToTensor()
)