因为cifar-10数据集在原地址下载速度很慢,所以在网上找百度网盘链接下载,后将压缩包放到指定目录,运行代码出错
报错详情:
Traceback (most recent call last):
File "C:\Users\Administrator\Documents\python\train.py", line 26, in <module>
download=False,transform = transform)
File "E:\Anaconda3\lib\site-packages\torchvision\datasets\cifar.py", line 61, in __init__
raise RuntimeError('Dataset not found or corrupted.' +
RuntimeError: Dataset not found or corrupted. You can use download=True to download it`
代码展示
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
transform = transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))
])
# 下载CIFAR10数据集
trainset = torchvision.datasets.CIFAR10(root='./datasets',train=True,
download=False,transform = transform)
# 加载数据集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,shuffle=True,num_workers=0)
# 下载测试集
testset = torchvision.datasets.CIFAR10(root='./datasets',train=False,
download=False,transform=transform)
# 加载测试集
testloader = torch.utils.data.DataLoader(testset,batch_size=2,shuffle=True,num_workers=0)
解决办法:
将download=False改为download=True
原因
因为如果是压缩包的话,程序无法加载,所以需要解压,download=True时代码会自动检测相应目录下是否有文件,如果有,就不下载,并把文件解压