torchvision数据集下载位置修改

torchvision数据集下载位置修改

记录一下,方便后边查阅。

1.torchvision基础介绍

torchvision是pytorch的一个图形库,它服务于深度学习Pytorch框架,主要用来构建计算机视觉模型。
下面是torchvision的构成[1]

	1.torchvision.datasets:一些加载数据的函数及常用的数据集接口;
	2.torchvision.models:包含常用的模型结构,例如AlexNet,VGG,ResNet等;
	3.torchvision.transforms:常用的一些图片变换,例如图片裁剪、选择等;
	4.torchvison.utils:其他一些有用的方法

2.torchvision常用数据集

在这里插入图片描述
官方的数据集用法:

在这里插入图片描述
torchvision下的常用数据集:
在这里插入图片描述

3.数据集用法介绍及root路径解释

代码示例:

def get_train_dataset():
    return dataset.FashionMNIST(
        root='./data',
        train=True,
        download=True,
        transform=getTransforms()
其中,
root:表示数据集下载保存位置
train:表示下载的数据集是不是训练集,True表示训练集,False表示测试集
download:表示数据集是否需要下载
transform:表示图片变换的一系列操作

root路径详解:

root='/':表示根目录下,如果你的代码保存在D盘,就是下载到D盘根目录下
root='./':表示当前文件夹下
root='':效果等同于'./'
root='./data':表示在当前文件夹下的data(如果没有,则会新建一个)文件夹下保存数据集
root='data':效果等同于'./data'

4.本文训练LeNet的数据处理代码

代码出自文献[2]:

FashionMNIST.py

import torch.utils.data
from torchvision import datasets as dataset
from torchvision import transforms


def getTransforms():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3018,))]
    )
    return transform


def get_train_dataset():
    return dataset.FashionMNIST(
        root='',
        train=True,
        download=True,
        transform=getTransforms()
    )


def get_test_dataset():
    return dataset.FashionMNIST(
        root='',
        train=False,
        download=True,
        transform=getTransforms()
    )


def get_train_loader(batch_size, shuffle=True):
    return torch.utils.data.DataLoader(
        dataset=get_train_dataset(),
        batch_size=batch_size,
        shuffle=shuffle
    )


def get_test_loader(batch_size, shuffle=True):
    return torch.utils.data.DataLoader(
        dataset=get_test_dataset(),
        batch_size=batch_size,
        shuffle=shuffle
    )

MNIST.py

import torch.utils.data
from torchvision import datasets as dataset
from torchvision import transforms


def getTransforms():
    transform = transforms.Compose(
        [transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3018,))]
    )
    return transform


def get_train_dataset():
    return dataset.MNIST(
        root='LeNetTest',
        train=True,
        download=True,
        transform=getTransforms()
    )


def get_test_dataset():
    return dataset.MNIST(
        root='LeNetTest',
        train=False,
        download=True,
        transform=getTransforms()
    )


def get_train_loader(batch_size, shuffle=True):
    return torch.utils.data.DataLoader(
        dataset=get_train_dataset(),
        batch_size=batch_size,
        shuffle=shuffle
    )


def get_test_loader(batch_size, shuffle=True):
    return torch.utils.data.DataLoader(
        dataset=get_test_dataset(),
        batch_size=batch_size,
        shuffle=shuffle
    )

第一次写,可能写得不是很好,希望大家多多包涵!

文献

[1]:https://wenku.baidu.com/view/21bbc06bf4ec4afe04a1b0717fd5360cba1a8df6.html
[2]:https://blog.csdn.net/weixin_38878828/article/details/125614377

  • 2
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

只想睡觉111

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值