Pytorch深度学习-torchvision中的数据集使用(小土堆)

  1. 在https://pytorch.org/vision/stable/datasets.html中存在很多官方数据集
  2. 使用示例(CIFAR10数据集使用)
'''
1. root (str 或 `pathlib.Path` ) – 数据集的根目录 `cifar-10-batches-py` ,如果 download 设置为 True,则该目录存在或将保存到其中
2. train (bool, optional) – 如果为 True,则从训练集创建数据集,否则从测试集创建数据集。
3. transform (callable, optional) – 接受 PIL 图像并返回转换后的版本的函数/转换。例如, `transforms.RandomCrop`
4. target_transform (callable, optional) – 接收目标并对其进行转换的函数/转换。
5. download (bool, 可选) – 如果为 true,则从 Internet 下载数据集并将其放在根目录中。如果已下载数据集,则不会再次下载。
'''
torchvision.datasets.CIFAR10(root:Union[str,Path], train:bool=True, transform:Optional[Callable]=None, target_transform:Optional[Callable]=None, download:bool=False
  1. 代码示例
import torchvision  
  
#下载地址为本目录的dataset文件夹下  
#因为train为true所以从训练集下载数据  
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,download=True)  
#因为train为false所以从测试集下载数据  
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,download=True)

print(test_set[0])
#获取数据集特征名  
print(test_set.classes)  
#使用img获取数据集的图片,使用target获取目标值名  
img,target = test_set[0]  
print(img)  
print(target)
  1. 将获取到的图片进行格式转换
#这里转换主要是设置了一个转换器并将其放置于Compose中,设置为一个转换器,并传入在数据获取的参数之中
import torchvision  
datasets_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  
#下载地址为本目录的dataset文件夹下  
#因为train为true所以从训练集下载数据  
train_set = torchvision.datasets.CIFAR10(root="./dataset",train=True,transform=datasets_transform,download=True)  
#因为train为false所以从测试集下载数据  
test_set = torchvision.datasets.CIFAR10(root="./dataset",train=False,transform=datasets_transform,download=True)
PyTorch是一个用于深度学习的开源框架,它提供了一组工具和接口,使得我们可以轻松地进行模型训练、预测和部署。在PyTorch,数据处理是深度学习应用的重要部分之一。 PyTorch的数据处理主要涉及以下几个方面: 1.数据预处理:包括数据清洗、数据归一化、数据增强等操作,以提高模型的鲁棒性和泛化能力。 2.数据加载:PyTorch提供了多种数据加载方式,包括内置的数据集、自定义的数据集和数据加载器等,以便我们更好地管理和使用数据。 3.数据可视化:为了更好地理解数据和模型,PyTorch提供了多种数据可视化工具,如Matplotlib、TensorBoard等。 下面是一个简单的数据预处理示例,展示如何将图像进行归一化和数据增强: ```python import torch import torchvision.transforms as transforms from torchvision.datasets import CIFAR10 # 定义一个数据预处理管道 transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]) ]) # 加载CIFAR10数据集,进行预处理 trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2) ``` 在上面的例子,我们首先定义了一个数据预处理管道,其包括了对图像进行随机裁剪、水平翻转、归一化等操作。然后,我们使用PyTorch内置的CIFAR10数据集,并将其预处理后,使用DataLoader进行批量加载。这个过程可以帮助我们更好地管理和使用数据,同时提高模型的训练效率和泛化能力。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值