pytorch----数据处理,datasets、DataLoader及其工具的使用

torchvision是PyTorch的一个视觉工具包,提供了很多图像处理的工具。

datasets使用ImageFolder工具(默认PIL Image图像),获取定制化的图片并自动生成类别标签。如裁剪、旋转、标准化、归一化等(使用transforms工具)。

DataLoader可以把datasets数据集打乱,分成batch,并行加速等。

一、datasets获取原图或格式化的图,自动命名标签

1.1 获取原图片

使用torchvision.datasets中的ImageFolder工具,功能:

1、文件夹名就是类别名

2、从上到下自动为文件夹自动创建标签,0、1、2、...。class_to_idx、imgs属性可以查看。

3、返回每一幅图的data、label

 

from torchvision.datasets import ImageFolder

dataset=ImageFolder("E:/data/dogcat_2/train/") #获取路径,返回的是所有图的data、label
print(dataset.class_to_idx) #查看类别名,及对应的标签。
print(dataset.imgs)  #查看路径里所有的图片,及对应的标签

print(dataset[0][1]) #第1张图的label
dataset[0][0] #第1张图的data

1.2 获取定制化的图片,启用ImageFolder的transform参数

使用torchvision的transforms工具,常用功能:

Resize——调整大小
CenterCrop、RandomCrop、RandomSizedCrop——裁剪
Pad——填充
ToTensor——PIL Image转Tensor,自动[0,255]归一化到[0,1]
Normalize——标准化,即减均值,除以标准差
ToPILImage——Tensor转PIL Image

这些操作可以放到一起——Compose

from torchvision import transforms as T

#设置格式化条件
transform=T.Compose([T.Resize((200,200)), #缩放为200*200方形
                     T.RandomHorizontalFlip(), #水平翻转
                     T.ToTensor(), #PIL Image转Tensor,[0,255]自动归一化为[0,1]
                     T.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5]) #标准化,减均值除标准差
                    ])
#启用ImageFolder的transform参数,获取格式化图像
dataset=ImageFolder("E:/data/dogcat_2/train/",transform=transform)

dataset[0][0].size() #查看图像大小,3*224*224

#展示图像,乘标准差加均值,再转回PIL Image(上述过程的逆过程)
show=T.ToPILImage()
show(dataset[0][0]*0.5+0.5)

 

二、DataLoader处理datasets

from torch.utils.data import DataLoader
dataloader=DataLoader( dataset,batch_size=4,shuffle=True,num_workers=2 ) #4幅图为1个batch,打乱,2个进程加速
#### 显示第1个batch的4幅图(随机)
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
dataiter = iter(dataloader) #DataLoader是可迭代的
(images, labels) = dataiter.next() #第一个batch
print(labels) #打印标签
show=ToPILImage() 
show(make_grid(images*0.5+0.5)).resize((4*100,100))  #以100*100展示第一个batch

附:Python transforms.RandomCrop方法代码示例

https://vimsky.com/examples/detail/python-method-torchvision.transforms.RandomCrop.html

以下是一些使用PyTorch Lightning的步骤: 1. 安装PyTorch Lightning:您可以使用pip安装PyTorch Lightning,命令如下: ``` pip install pytorch-lightning ``` 2. 创建LightningModule:LightningModule是PyTorch Lightning的核心组件,它是您定义模型结构和训练循环的地方。您可以创建一个类来定义您的模型和训练循环。以下是一个简单的示例: ``` import torch.nn as nn import pytorch_lightning as pl class MyModel(pl.LightningModule): def __init__(self): super().__init__() self.layer1 = nn.Linear(28*28, 128) self.layer2 = nn.Linear(128, 10) def forward(self, x): x = x.view(x.size(0), -1) x = nn.functional.relu(self.layer1(x)) x = self.layer2(x) return x def training_step(self, batch, batch_idx): x, y = batch y_hat = self.forward(x) loss = nn.functional.cross_entropy(y_hat, y) self.log('train_loss', loss) return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=1e-3) ``` 3. 创建DataModule:DataModule是PyTorch Lightning的另一个核心组件,它是您加载和预处理数据的地方。您可以创建一个类来定义如何加载和预处理数据。以下是一个简单的示例: ``` import torchvision.datasets as datasets import torchvision.transforms as transforms import pytorch_lightning as pl class MyDataModule(pl.LightningDataModule): def __init__(self): super().__init__() self.transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)) ]) self.train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=self.transform) self.test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=self.transform) def train_dataloader(self): return torch.utils.data.DataLoader(self.train_dataset, batch_size=32, shuffle=True) def test_dataloader(self): return torch.utils.data.DataLoader(self.test_dataset, batch_size=32, shuffle=False) ``` 4. 创建Trainer:Trainer是PyTorch Lightning的训练器,它负责训练和验证模型。您可以创建一个Trainer对象并传递您的模型和数据模块。以下是一个简单的示例: ``` import pytorch_lightning as pl model = MyModel() data_module = MyDataModule() trainer = pl.Trainer(gpus=1, max_epochs=10) trainer.fit(model, data_module) ``` 在训练完成后,您可以使用以下代码评估模型: ``` trainer.test(datamodule=data_module) ``` 这些步骤可以让您快速开始使用PyTorch Lightning。您可以根据您的需求进一步定制和扩展它们。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值