pytorch入门(三)

本文是PyTorch入门系列的第三篇,介绍了PyTorch中的数据处理,包括Dataset和DataLoader的使用,以及torchvision包的功能。Dataset用于存储样本及其标签,DataLoader实现批量数据加载。torchvision提供了常见数据集、模型结构和图像转换工具。文章还简要提及了使用CUDA加速训练、模型保存与加载、模型训练进度的可视化。
摘要由CSDN通过智能技术生成

在前两篇文章中我们了解了pytorch的一些基础概念以及模型训练的整体框架。今天我们继续介绍pytorch的一些常用包。

传送门:

pytorch入门(一)

pytorch入门(二)

Dataset&DataLoader


对于数据处理上,我们希望样本数据集代码和模型训练代码分开,这样便于获得更强的可读性和模块化。因此,pytorch提供了两个数据元素:

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

Dataset存储样本示例以及其对应的标签;DataLoader包装了数据集的可迭代对象,方便访问样本,可以认为是一个数据加载器。

Dataset是一种抽象类,所有继承它的子类都应该包含__len__方法和__getitem__方法,前者表示数据集大小,后者支持实例化后的样本整数索引。一般情况下,继承此类的对象也都包含初始化方法__init__.

下面举一个例子:

import os
import pandas as pd
from torchvision.io import read_image


class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform


    def __len__(self):
        return len(self.img_labels)


    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

img_dir是存储数据(这个例子应该是图片)路径,annotations_file是存储图片对应标签的CSV文件。然后,方法__len__返回的是数据集所有图片的个数__getitem__方法是在给定样本的索引后,加载样本数据以及其对应的标签。两个tranform是做一些必要的转换,这里先不用管。

在训练模型的时候,一般只取批量样本进行训练,因此,DataLoader实现这种“小批量”训练过程:

from torch.utils.data import DataLoader


DataLoader(
  • 3
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

整得咔咔响

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值