小土堆:Pytorch深度学习:Dataset类

作者因为课题需求,刚接触Pytorch,这里是我的学习笔记分享,一方面作为笔记记录,加深印象,另一方面也是希望可以帮助大家。

课程来源:

PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

https://www.bilibili.com/video/BV1hE411t7RN?p=7&vd_source=1633322d160e1c2471aa2b3eb28bde0a

蚂蚁蜜蜂/练手数据集:链接: https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA 密码: 5suq

代码讲解:主要内容是数据集加载

导入package:

from torch.utils.data import Datasetfrom PIL import Imageimport os

创建了一个MyData类用于存放数据集,继承Dateset的类:

class MyData(Dataset):

在实际项目时,要构建自己的数据集,需要继承Dateset的类,Pytorch才能读取使用。想要实现这个类,必须要重写3个方法:init(self, 参数…)、 getitem(self, index)、len(self)。

1)进行初始化,加载相应的参数,self代表实例本身

 def __init__(self,root_dir,label_dir):

2)两个全局变量,第一个是相对路径,第二个是标签(就是文件夹名字)

 self.root_dir=root_dir self.label_dir=label_dir

3)通过os函数,把目录和文件名合成一个路径

self.path=os.path.join(self.root_dir,self.label_dir)

4)img_path通过os中的listdir方法返回一个包含全部文件名的列表,便于调用等

self.img_path=os.listdir(self.path)

索引 index:

    def __getitem__(self, idx):                                                         img_name=self.img_path[idx]        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)        img=Image.open(img_item_path)        label=self.label_dir        return img, label      def __len__(self):        return len(self.img_path)

数据集本质应当是所有数据样本的一个列表,因此每个样本都有对应的索引index。我们取用一个样本最简单的方式就是用该样本的index从数据列表中把它取出来。__getitem__就是做这样一件事。

1)一般如果想使用索引访问元素时,就可以在类中定义这个方法(__getitem__(self, idx) )。

 def __getitem__(self, idx):

2)读取列表,获得图片名称

img_name=self.img_path[idx]

3)拼接相对地址, 文件夹名称, 图片名称的最终路径

img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)

4)读取图片

img=Image.open(img_item_path)

5)获取了文件中的标签信息,并将图片和标签返回

label=self.label_dirreturn img, label

列表长度:

用于获取列表的长度,返回数字,注意从0开始

def __len__(self):        return len(self.img_path)

生成文档:

root_dir="dataset/train"ants_label_dir="ants_image"ants_dataset=MyData(root_dir,ants_label_dir)target_dir="ants_image"label=target_dir.split('_')[0]img_path=os.listdir(os.path.join(root_dir,target_dir))out_dir="ants_label"for i in img_path:    file_name=i.split('.jpg')[0]    with open(os.path.join(root_dir,out_dir,"{}.txt".format(file_name)),'w') as f:        f.write(label)

label 标签 target 目标 

其中str.split('_')[0]是切片函数,就是将str这个字符串每有一个_切一刀,然后取第【i】段

举个例子:

str="linlinxuezhang_dashuaige_jushiwushaung"

str.split('_')[0] 就是linlinxuezhang

str.split('_')[1] 就是dashuaige

写文件操作,将标签写入root_dir/out_dir并创建txt文件,文件内容为label
 'w': 打开一个文件只用于写入。如果该文件已存在则将其覆盖。如果该文件不存在,创建新文件。

整段代码如下:

from torch.utils.data import Datasetfrom PIL import Imageimport osclass MyData(Dataset):    def __init__(self,root_dir,label_dir):         self.root_dir=root_dir        self.label_dir=label_dir        self.path=os.path.join(self.root_dir,self.label_dir)        self.img_path=os.listdir(self.path)     def __getitem__(self, idx):                                                         img_name=self.img_path[idx]        img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)        img=Image.open(img_item_path)        label=self.label_dir        return img, label      def __len__(self):        return len(self.img_path) root_dir="dataset/train"ants_label_dir="ants_image"ants_dataset=MyData(root_dir,ants_label_dir)target_dir="ants_image"label=target_dir.split('_')[0]img_path=os.listdir(os.path.join(root_dir,target_dir))out_dir="ants_label"for i in img_path:    file_name=i.split('.jpg')[0]    with open(os.path.join(root_dir,out_dir,"{}.txt".format(file_name)),'w') as f:        f.write(label)
  • 2
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值