pytorch框架学习(dataset类代码)

目录

前言:

什么是dataset?

dataset类

python的os库

python的PIL库

dataset子类代码

导库:

模板:

__init__:

__getitem__:

__len__:

子类完整代码:

测试:

尾声:

补充:


前言:

该专栏的前几篇介绍完了yolov5的基本框架,本章开始将进行pytorch的框架学习,如果你也跟我一样是个深度学习小白想要学习yolo算法,下面给出一点学习方向的建议——


什么是dataset?

dataset直译为数据集,个人理解成两个含义:

①数据集文件:也就是从网络上、或者别人分享,然后我们下载、解压后放在电脑硬盘中的一个文件,常见的数据集由训练集(train)和测试集(val)两部分组成,两者都有数量很多的图片,前者用来训练模型,后者用来测试训练好的模型

在train集和val集里会根据label(标签)进行图片分类,比如下图的蚂蚁(ants)和蜜蜂(bees)

②数据集“变量”:这个是指数据集在代码中存在的形式,显然,如果我们想要在python代码中对图片进行一些相关操作(查看图片、获取图片信息、提取特征等等……),就需要一些处理,把存放在硬盘中的图片文件的相关信息(比如图片的存放路径、图片的尺寸大小)转换成代码


dataset类

下载、解压一个数据集是再简单不过的操作,没什么好讲的,接下来就从具体代码入手。在此说明:类是面向对象程序设计的最基础也最重要的概念,如果不懂这个概念,只建议先去看python基础(当然如果你C++或者java方面的基础掌握的不错,兴许也能看懂大半)!

dataset类是pytorch中封装好的一个抽象类,我们需要自定义一个子类来继承该类,并实现其中的抽象方法(注:这是从C++和Java的视角去理解,在Python中,下面这三个函数被称为魔法函数->Magic Function),使用这个类前需要导入相应的包,即如下代码:

from torch.utils.data import Dataset

在Dataset类的子类中,应该有以下函数实现相应功能:

__init__:初始化每一个数据(对图片文件的路径进行一些设置,否则怎么取得这张图片?)
__getitem__:获取每一个数据(对照yolo项目也就是图片文件),及其对应的label(即图片中物体的类别,是ants还是bees?)
__len__:统计数据集中的数据数量(即该数据集总共有多少张图片?如果连有多少张图片都不知道,要怎么确认训练的次数?)


python的os库

os全称:operating system(对系统操作)

它提供的是各种 Python 程序与操作系统进行交互的接口。通过使用os模块,一方面可以方便地与操作系统进行交互,另一方面还可以极大增强代码的可移植性

本篇代码将会用到os库中的如下两个方法:

os.path.join()  (点方法名传送查看更细致的讲解)      #这个方法用来把传入的路径(可传入多个)拼接在一起,代码示例:

import os    #导入操作系统库

str1 = 'C:'
str2 = '目标检测'
str3 = 'yolov5'
path = os.path.join(str1,str2,str3)
print(path)#C:目标检测\yolov5


os.listdir()        #这个方法用来把传入的路径下的所有文件的文件名字存入一个列表并返回该列表,代码示例:

import os

str1 = 'C:'
str2 = '目标检测'
str3 = 'yolov5'
path = os.path.join(str1,str2,str3)
print(path)#C:目标检测\yolov5
my_list = os.listdir(path)
print(my_list)#['薯片.jpg', '饮料.png']<-注意我在自己电脑的这个路径下是确实放了这两个文件的,跟我一样的小白不要没输出一样的结果就懵圈、怀疑写错了。。。


python的PIL库

PIL全称:Python Imaging Library(图像馆?)

它能提供一些对图片文件的操作(例如打开、保存、显示图片,转换图片模式等等),因为代码需要,本篇只先初步认识一下,不细讲

本篇代码将会用到该库的open()->打开图片,和show()->显示图片

下面是导入PIL库的代码:

from PIL import Image# 导入python的PIL库


dataset子类代码

导库:

首先把用到的工具包和库导入一下

from torch.utils.data import Dataset# 导入dataset类
from PIL import Image# 导入python的PIL库
import os# 系统库

模板:

把模板写一下,子类继承 dataset这个抽象类,然后需要实现3个抽象方法,如下

class MyData(Dataset):    #这个子类必须继承Dataset类
    def __init__(self, root_dir, label_dir):    #初始化数据
        

    def __getitem__(self, idx):    #获取数据
        

    def __len__(self):    #获取数据集的长度
        

__init__:

这个函数是用来初始化数据的,作用前面提过。接收三个参数,这里的self是固定的参数(相当于C++或Java的this关键字,后面两个方法同理),另外传入两个参数分别表示父目录(上一级目录)和标签目录(即数据当前所在当前目录),然后利用os.path.join()方法把这两个组成一个相对路径,通过这个路径也就能找到数据集在硬盘中的存储位置了,最后用os.listdir()方法把这个路径下的所有(图片)文件的名字加后缀(例如:薯片.jpp)放入列表中,如下:

有了这个列表,我们就可以很方便的用父目录拼接图片名得到任意一张图片的存放路径了

    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir  # 父目录
        self.lable_dir = label_dir  # 标签目录
        self.path = os.path.join(self.root_dir, self.lable_dir)  # 组成相对路径
        self.img_path = os.listdir(self.path)  # 把该路径下的所有文件名生成列表

__getitem__:

这个函数用来获取某个数据及其标签

首先self.img_path列表是上上一个__init__方法的成果,通过下标就能获得对应图片的名字

接着再用os.path.join()把父目录和文件名拼接起来,得到图片的路径

然后用Image.open()函数,传入该图片的路径,就能把这张图片打开,它是一个具体的对象,赋值给img变量作为返回值,label的话就更简单了,在__init__方法中已经作为成员变量存放好了,直接用self调用即可

    def __getitem__(self, idx):
        img_name = self.img_path[idx]  # 通过下标索引,获取(图片)文件名
        img_item_path = os.path.join(self.root_dir, self.lable_dir, img_name)  # 组成(图片)文件的具体路径
        img = Image.open(img_item_path)  # 打开该(图片)文件
        label = self.lable_dir;
        return img, label  # 返回获取的(图片)文件和它的标签名

__len__:

这个比前两个简单多了,直接用len()方法求一下列表长度返回,稍微有点基础的都应该看得懂

    def __len__(self):  # 获取数据集的长度
        return len(self.img_path)

子类完整代码:

from torch.utils.data import Dataset# 导入dataset类
from PIL import Image# 导入python的PIL库
import os# 系统库


class MyData(Dataset):  # 继承Dataset类
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir  # 父目录
        self.lable_dir = label_dir  # 标签目录
        self.path = os.path.join(self.root_dir, self.lable_dir)  # 组成相对路径
        self.img_path = os.listdir(self.path)  # 把该路径下的所有文件名生成列表

    def __getitem__(self, idx):
        img_name = self.img_path[idx]  # 通过下标索引,获取(图片)文件名
        img_item_path = os.path.join(self.root_dir, self.lable_dir, img_name)  # 组成(图片)文件的具体路径
        img = Image.open(img_item_path)  # 打开该(图片)文件
        label = self.lable_dir;
        return img, label  # 返回获取的(图片)文件和它的标签名

    def __len__(self):  # 获取数据集的长度
        return len(self.img_path)

测试:

root_dir1 = "MYDATA/train"#父目录
label_dir1 = "ants"#标签
test1 = MyData(root_dir1, label_dir1)#拼接得到ants的相对路径
print(test1.__len__())#输出一下ants下有多少张图片?结果是124
img1, label1 = test1.__getitem__(0)#调用方法获取第一张图片
img1.show()#用show()方法显示一下图片

尾声:

差点忘了分享链接,数据集是在这里下载的(点击跳转哔哩哔哩平台看视频简介)

补充:

在数据集文件夹中,通常还会创建一个label文件夹来存放对应图片的txt文本label,如下:

其中,image文件夹用来存放图片,label文件夹用来存放对应图片内容的类别(即标签),对应的文件名要相同(后缀不同),如下:

怎么去生成label文件中的txt文件呢?手动一个个改显然不现实,效率太低,用程序来实现就好了:

代码:

import os

root_dir = "MYDATA/train"
target_dir = "ants_image"
img_path = os.listdir(os.path.join(root_dir, target_dir))
label = target_dir.split('_')[0]
out_dir = "ants_label"#注意,这个文件夹要手动提前创建好
for i in img_path:
    file_name = i.split('.jpg')[0]
    with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
        #注意这个'w'参数是属于open函数的,不要写到os.path.join里面去了。。。血的教训
        f.write(label)

建议重开一个python文件写上述代码,运行如上代码后,ants_label文件夹中就会生成ants_image文件夹中的图片文件对应的label,原理也很简单,稍微琢磨一下能看懂的!至于bees_label文件夹,依葫芦画瓢就行了,补充内容结束!

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch是一个非常流行的深度学习框架,可以用于实现中文文本分类任务。下面是一个简单的示例代码,用于对中文文本进行分类: 首先,我们需要导入必要的库和模块: ```python import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torchtext import data from torchtext.vocab import Vectors ``` 接下来,我们定义一个类来构建我们的文本分类模型: ```python class TextClassifier(nn.Module): def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim): super(TextClassifier, self).__init__() self.embedding = nn.Embedding(vocab_size, embedding_dim) self.rnn = nn.GRU(embedding_dim, hidden_dim, num_layers=2, bidirectional=True) self.fc = nn.Linear(hidden_dim * 2, output_dim) def forward(self, text): embedded = self.embedding(text) output, _ = self.rnn(embedded) hidden = torch.cat((output[-2, :, :], output[-1, :, :]), dim=1) return self.fc(hidden) ``` 然后,我们加载和预处理数据,这里使用了`torchtext`库来进行数据处理: ```python TEXT = data.Field(tokenize='jieba') LABEL = data.LabelField() dataset = data.TabularDataset('data.csv', format='csv', fields=[('text', TEXT), ('label', LABEL)]) train_data, test_data = dataset.split(split_ratio=0.9) TEXT.build_vocab(train_data, vectors=Vectors('vec.txt')) LABEL.build_vocab(train_data) train_iterator, test_iterator = data.BucketIterator.splits((train_data, test_data), batch_size=64, shuffle=True) ``` 接下来,我们定义模型参数和优化器,并进行训练和评估: ```python vocab_size = len(TEXT.vocab) embedding_dim = 100 hidden_dim = 256 output_dim = len(LABEL.vocab) model = TextClassifier(vocab_size, embedding_dim, hidden_dim, output_dim) optimizer = optim.Adam(model.parameters()) criterion = nn.CrossEntropyLoss() model.train() for epoch in range(10): for batch in train_iterator: text, label = batch.text, batch.label optimizer.zero_grad() output = model(text) loss = criterion(output, label) loss.backward() optimizer.step() model.eval() correct = 0 total = 0 for batch in test_iterator: text, label = batch.text, batch.label output = model(text) predicted = torch.argmax(output, dim=1) correct += (predicted == label).sum().item() total += label.size(0) accuracy = correct / total print(f'Accuracy: {accuracy:.4f}') ``` 以上就是使用PyTorch实现中文文本分类的基本过程。通过加载和预处理数据,构建模型,并通过训练和评估来对文本进行分类。当然,这只是一个简单的示例代码,你可以根据自己的需求进行调整和扩展。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值