在pytorch中使用自己的数据集,dataset的写法

引入

在学习pytorch的过程中,用的一直都是教程中别人定义好从网上直接下载的数据集,不需要进行任何的处理,数据和标号都可以直接获取。但是,我想要进行自己的研究大多数情况需要我们自己收集数据并进行一些预处理在制作成数据集,然后通过pytorch读入后用来训练模型。这里记录的是一次对上万张验证码图片组成的数据集(标号是其名称)制作pytorch数据集的尝试。

部分数据如下:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BF0drwOa-1647960599298)(attachment:51c03f9d-a46c-4994-8e7e-36c078c6a724.png)]

大多数教程中并没有讲这些图片数据和标签是如何装载到torch中的,在分析了一个github项目https://github.com/braveryCHR/CNN_captcha 后我大概了解如何装载数据。

方法

如果我们需要利用pytorch装载数据以及标签,我们就必须自己写一个dataset类,该类要继承data.Dataset类,该类在torch.utils中,并实现该类的_getitem_和_len_方法。
示例:

为了实现将验证码分类,我们先定义label和字符互相转换的函数:

import os

import torch
from PIL import Image
from torch.utils import data
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms as T

def StrToLabel(Str):
    # print(Str)
    label = []
    for i in range(0, charNumber):
        if '0' <= Str[i] <= '9':  # 数字
            label.append(ord(Str[i]) - ord('0'))
        elif 'a' <= Str[i] <= 'z':  # 小写字母
            label.append(ord(Str[i]) - ord('a') + 10)
        else:  # 大写字母
            label.append(ord(Str[i]) - ord('A') + 36)
    return label


def LabelToStr(Label):
    Str = ""
    for i in Label:
        if i <= 9:
            Str += chr(ord('0') + i)
        elif i <= 35:
            Str += chr(ord('a') + i - 10)
        else:
            Str += chr(ord('A') + i - 36)
    return Str

接下来是数据集合类的定义

class Captcha(data.Dataset):
    def __init__(self, root, train=True):
        self.imgPath = [os.path.join(root, img) for img in os.listdir(root)]
        self.transform = T.Compose([
            T.Resize((150, 30)),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __getitem__(self, index):
        img_path = self.imgPath[index]
        label = img_path.split('\\')[-1].split('.')[0]       #获取图片标签
        label_tensor = torch.Tensor(StrToLabel(label))
        data=Image.open(img_path)
        data = self.transform(data)  # 使用PLT打开图片文件
        return data, label_tensor

    def __len__(self):
        return len(self.imgPath)

在init中的transform是预处理的定义。

getitem方法用来返回读取的图片数据和该图片的参数,我们将图片文件名获取到并转换为tensor,再使用PIL模块中的Image.open()读取图片数据,之后通过预处理transform转为tensor对象,最后返回图片数据data和图片标签label_tensor就可以了。


len函数返回文件中图片的数量。


dataloader会根据len读取文件中所有图片,每次读取图片的方法就是getitem中定义的方法。

测试

我们来使用一下这个Capthca类,看看能否正确读取图片数据data以及其标号label

import os.path
import torch
import torchvision
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹

img_data = Captcha("./data/train/train", train=True)
trainDataLoader = DataLoader(img_data, batch_size=1,
                             shuffle=False, num_workers=4)

if __name__ == '__main__':
    # for i, data in enumerate(trainDataLoader, 0):
    #     inputs, label = data
    #     print(label)
    it = trainDataLoader.__iter__()#使用迭代器返回第一张图片的数据和标签
    data, label = it.next()
    print(data)
    print(label)
    print(LabelToStr(int(x)for x in label.squeeze().tolist()))

由于在jupyter中运行该代码会报错所以我放上在pycharm上的运行结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bjSdGSsA-1647960599299)(attachment:f17724e4-267e-40ad-ab0d-5f041787eee2.png)]

总结

想要使用自己定义的数据集就必须实现一个dataset,使得dataloader知道如何获取数据以及标签。

  • 4
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
好的,我来给你讲一下如何用 PyTorch 加载自己的数据集。 首先,你需要创建一个自定义的 Dataset 类,它必须包含两个方法:\_\_len\_\_ 和 \_\_getitem\_\_。 ```python import torch.utils.data as data class MyDataset(data.Dataset): def __init__(self, data_path): # 初始化数据集 self.data = [] with open(data_path, 'r') as f: for line in f: self.data.append(line.strip()) def __len__(self): # 返回数据集的长度 return len(self.data) def __getitem__(self, index): # 根据索引返回一条数据 return self.data[index] ``` 在上面的代码,我们首先导入了 PyTorch 的 data 模块,然后定义了一个 MyDataset 类。这个类的构造函数需要传入数据集的路径,然后读取数据集并进行初始化。在 \_\_len\_\_ 方法,我们返回了数据集的长度,这个方法会被 DataLoader 调用以确定数据集的大小。在 \_\_getitem\_\_ 方法,我们根据索引返回一条数据,这个方法会被 DataLoader 调用以获取数据。 接下来,我们需要创建一个 DataLoader 对象来读取数据集。DataLoader 会按照一定的 batch_size 对数据集进行分批,并且提供数据的迭代器。 ```python from torch.utils.data import DataLoader # 创建数据集 dataset = MyDataset('data.txt') # 创建 DataLoader dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4) ``` 在上面的代码,我们首先创建了一个 MyDataset 对象,然后使用 DataLoader 将这个数据集分批。batch_size 参数指定了每个 batch 包含的样本数,shuffle 参数指定了是否打乱数据集,num_workers 参数指定了使用多少个子进程来读取数据集。 现在,我们可以通过迭代 DataLoader 来读取数据了。 ```python for batch in dataloader: # 处理数据 pass ``` 在上面的代码,我们通过迭代 dataloader 来读取数据集,每个 batch 的数据会被封装成一个 Tensor 对象,我们可以直接对这个 Tensor 进行操作。 希望这个回答能够帮助到你!

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值