解决pytorch下载MNIST数据集缓慢的问题,非常简单

问题描述

train_loader=torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, 
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
    batch_size=BATCH_SIZE,shuffle=True)

当我们使用上述代码加载MNIST数据时,数据集的下载非常缓慢,经常卡死。

解决方法

进入MNIST数据集网站,下载4个数据集压缩包,不要解压(因为pytorch还会对数据集进行一些处理,而不是单纯的解压,所以直接手动解压是不行的)。4个压缩包放在/data/MINIST/raw文件夹内,如下图所示:
在这里插入图片描述这里我的文件路径是D:\aaa\pytorch-exercise\0my-exercise\data\MNIST\raw

之后,在python的库中找到torchvison库,再找到实现mnist加载的文件mnist.py。其在python中的大概路径如下:[你的python的位置]\lib\site-packages\torchvision\datasets\mnist.py

打开mnist.py,将class MNIST的mirrors变量改成 file:///后面加上上述4个压缩包所在的路径,resources变量不要改动,如下图所示:
在这里插入图片描述

mirrors中的路径,要写你的电脑中那4个压缩包所在的路径。

修改好后,再运行一遍加载MNIST数据的代码块,发现成功加载好了数据,如下:
在这里插入图片描述
去/data/MNIST/raw文件夹内查看,多出了4个文件:
在这里插入图片描述
至此,问题就解决了。我看网上其他博客,说raw的同级目录processed目录中还应该有两个文件,但我没管那两个文件,我的模型一样成功跑起来了。那就先不管那两个文件了。

具体的原理,大家可以看一下mnist.py的代码,尤其是self.download函数。
大家的torchvision的版本可能不一样。我还看过其他人的类内变量为url,而不是mirrors和resources,其实如果是url改动方法也和上面类似(其实更好改了)。

  • 7
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值