在pytorch中使用自己的数据集,dataset的写法

引入

在学习pytorch的过程中,用的一直都是教程中别人定义好从网上直接下载的数据集,不需要进行任何的处理,数据和标号都可以直接获取。但是,我想要进行自己的研究大多数情况需要我们自己收集数据并进行一些预处理在制作成数据集,然后通过pytorch读入后用来训练模型。这里记录的是一次对上万张验证码图片组成的数据集(标号是其名称)制作pytorch数据集的尝试。

部分数据如下:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-BF0drwOa-1647960599298)(attachment:51c03f9d-a46c-4994-8e7e-36c078c6a724.png)]

大多数教程中并没有讲这些图片数据和标签是如何装载到torch中的,在分析了一个github项目https://github.com/braveryCHR/CNN_captcha 后我大概了解如何装载数据。

方法

如果我们需要利用pytorch装载数据以及标签,我们就必须自己写一个dataset类,该类要继承data.Dataset类,该类在torch.utils中,并实现该类的_getitem_和_len_方法。
示例:

为了实现将验证码分类,我们先定义label和字符互相转换的函数:

import os

import torch
from PIL import Image
from torch.utils import data
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms as T

def StrToLabel(Str):
    # print(Str)
    label = []
    for i in range(0, charNumber):
        if '0' <= Str[i] <= '9':  # 数字
            label.append(ord(Str[i]) - ord('0'))
        elif 'a' <= Str[i] <= 'z':  # 小写字母
            label.append(ord(Str[i]) - ord('a') + 10)
        else:  # 大写字母
            label.append(ord(Str[i]) - ord('A') + 36)
    return label


def LabelToStr(Label):
    Str = ""
    for i in Label:
        if i <= 9:
            Str += chr(ord('0') + i)
        elif i <= 35:
            Str += chr(ord('a') + i - 10)
        else:
            Str += chr(ord('A') + i - 36)
    return Str

接下来是数据集合类的定义

class Captcha(data.Dataset):
    def __init__(self, root, train=True):
        self.imgPath = [os.path.join(root, img) for img in os.listdir(root)]
        self.transform = T.Compose([
            T.Resize((150, 30)),
            T.ToTensor(),
            T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def __getitem__(self, index):
        img_path = self.imgPath[index]
        label = img_path.split('\\')[-1].split('.')[0]       #获取图片标签
        label_tensor = torch.Tensor(StrToLabel(label))
        data=Image.open(img_path)
        data = self.transform(data)  # 使用PLT打开图片文件
        return data, label_tensor

    def __len__(self):
        return len(self.imgPath)

在init中的transform是预处理的定义。

getitem方法用来返回读取的图片数据和该图片的参数,我们将图片文件名获取到并转换为tensor,再使用PIL模块中的Image.open()读取图片数据,之后通过预处理transform转为tensor对象,最后返回图片数据data和图片标签label_tensor就可以了。


len函数返回文件中图片的数量。


dataloader会根据len读取文件中所有图片,每次读取图片的方法就是getitem中定义的方法。

测试

我们来使用一下这个Capthca类,看看能否正确读取图片数据data以及其标号label

import os.path
import torch
import torchvision
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# 使用torchvision.datasets.ImageFolder读取数据集 指定train 和 test文件夹

img_data = Captcha("./data/train/train", train=True)
trainDataLoader = DataLoader(img_data, batch_size=1,
                             shuffle=False, num_workers=4)

if __name__ == '__main__':
    # for i, data in enumerate(trainDataLoader, 0):
    #     inputs, label = data
    #     print(label)
    it = trainDataLoader.__iter__()#使用迭代器返回第一张图片的数据和标签
    data, label = it.next()
    print(data)
    print(label)
    print(LabelToStr(int(x)for x in label.squeeze().tolist()))

由于在jupyter中运行该代码会报错所以我放上在pycharm上的运行结果:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bjSdGSsA-1647960599299)(attachment:f17724e4-267e-40ad-ab0d-5f041787eee2.png)]

总结

想要使用自己定义的数据集就必须实现一个dataset,使得dataloader知道如何获取数据以及标签。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值