[PyTorch Study Notes] Using torchvision.datasets & DataLoader

Course video: https://www.bilibili.com/video/BV1hE411t7RN?p=1 (includes environment setup)

Using torchvision.datasets

PyTorch ships with many built-in datasets for learning, so here we pick one to work with (see the official documentation).

CIFAR10

This tutorial uses the CIFAR10 dataset.

import torchvision

train_set = torchvision.datasets.CIFAR10(root="./CIFAR10_Dataset", train=True, download=True)
test_set = torchvision.datasets.CIFAR10(root="./CIFAR10_Dataset", train=False, download=True)
# root: directory the dataset is downloaded to
# train: True for the training set, False for the test set
# download: whether to download from the internet; if the download is slow,
# you can fetch the archive directly:
# https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

Check what type of data the dataset contains:

print(test_set[0])

The output shows that each sample is a (PIL Image, label) pair; the "3" is the class index. CIFAR10 has 10 classes; see the official documentation for the full list.

(<PIL.Image.Image image mode=RGB size=32x32 at 0x281591063D0>, 3)
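The class index can be mapped to a human-readable name. The ten CIFAR10 classes in label order are sketched below as a hardcoded list for illustration; after loading, the same list is also available as `test_set.classes`:

```python
# The ten CIFAR10 classes in label order (torchvision exposes the same
# list as test_set.classes once the dataset is loaded)
cifar10_classes = ["airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck"]

label = 3  # the label from test_set[0] above
print(cifar10_classes[label])  # cat
```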

Converting to Tensor

Pass the built-in transform argument when constructing the dataset:

test_data = torchvision.datasets.CIFAR10(root="./CIFAR10_Dataset",
                                         train=False,
                                         transform=torchvision.transforms.ToTensor())
print(test_data[0])

Now each sample is a tensor:

(tensor([[[0.6196, 0.6235, 0.6471,  ..., 0.5373, 0.4941, 0.4549],
         [0.5961, 0.5922, 0.6235,  ..., 0.5333, 0.4902, 0.4667],
         [0.5922, 0.5922, 0.6196,  ..., 0.5451, 0.5098, 0.4706],
         ...,
         [0.2667, 0.1647, 0.1216,  ..., 0.1490, 0.0510, 0.1569],
         [0.2392, 0.1922, 0.1373,  ..., 0.1020, 0.1137, 0.0784],
         [0.2118, 0.2196, 0.1765,  ..., 0.0941, 0.1333, 0.0824]],

        [[0.4392, 0.4353, 0.4549,  ..., 0.3725, 0.3569, 0.3333],
         [0.4392, 0.4314, 0.4471,  ..., 0.3725, 0.3569, 0.3451],
         [0.4314, 0.4275, 0.4353,  ..., 0.3843, 0.3725, 0.3490],
         ...,
         [0.4863, 0.3922, 0.3451,  ..., 0.3804, 0.2510, 0.3333],
         [0.4549, 0.4000, 0.3333,  ..., 0.3216, 0.3216, 0.2510],
         [0.4196, 0.4118, 0.3490,  ..., 0.3020, 0.3294, 0.2627]],

        [[0.1922, 0.1843, 0.2000,  ..., 0.1412, 0.1412, 0.1294],
         [0.2000, 0.1569, 0.1765,  ..., 0.1216, 0.1255, 0.1333],
         [0.1843, 0.1294, 0.1412,  ..., 0.1333, 0.1333, 0.1294],
         ...,
         [0.6941, 0.5804, 0.5373,  ..., 0.5725, 0.4235, 0.4980],
         [0.6588, 0.5804, 0.5176,  ..., 0.5098, 0.4941, 0.4196],
         [0.6275, 0.5843, 0.5176,  ..., 0.4863, 0.5059, 0.4314]]]), 3)
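ToTensor converts the PIL image (H×W×C, uint8 in 0–255) into a C×H×W float tensor scaled to [0, 1] by dividing each pixel by 255. As a quick arithmetic check, the first value 0.6196 above would come from a raw pixel value of 158 (an inferred value, used here only for illustration):

```python
# ToTensor maps each uint8 pixel p to p / 255.0;
# a raw pixel of 158 becomes roughly 0.6196, matching the first entry above
p = 158
print(round(p / 255, 4))  # 0.6196
```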

Using DataLoader

DataLoader is the utility that batches the dataset into a form suitable for feeding into a neural network.

from torch.utils.data import DataLoader

test_loader = DataLoader(dataset=test_data,
                         batch_size=4,    # how many samples per batch to load
                         shuffle=True,    # set to True to have the data reshuffled at every epoch
                         num_workers=0,   # how many subprocesses to use for data loading
                         drop_last=False) # set to True to drop the last incomplete batch
                                          # if the dataset size is not divisible by the batch size

for data in test_loader:
    imgs, targets = data
    print(imgs.shape)  # batch_size = 4: four images, each with 3 channels and 32x32 pixels
    print(targets)     # the class labels of the four images

One batch of output:

torch.Size([4, 3, 32, 32])
tensor([0, 8, 5, 1])

Use TensorBoard for a more intuitive view:

from torch.utils.tensorboard import SummaryWriter

test_loader = DataLoader(dataset = test_data,
                        batch_size = 64,   
                        shuffle = True,  
                        num_workers = 0, 
                        drop_last = False)

writer = SummaryWriter("dataloader")
step = 0
for data in test_loader:
    imgs,targets = data
    writer.add_images("test_data",imgs,step)
    step = step + 1
    
writer.close()

Each step contains 64 images.
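With the standard 10000-image CIFAR10 test set and batch_size = 64, the batch count works out as follows; since drop_last = False, the final step holds only the remainder:

```python
# 10000 test images split into batches of 64
num_samples, batch_size = 10000, 64
full_batches, remainder = divmod(num_samples, batch_size)
# drop_last=False keeps the last partial batch, so there are 157 steps in total,
# the last one holding 16 images
print(full_batches, remainder)  # 156 16
```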
