作者因为课题需求,刚接触Pytorch,这里是我的学习笔记分享,一方面作为笔记记录,加深印象,另一方面也是希望可以帮助大家。
课程来源:
PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】
https://www.bilibili.com/video/BV1hE411t7RN?p=7&vd_source=1633322d160e1c2471aa2b3eb28bde0a
蚂蚁蜜蜂/练手数据集:链接: https://pan.baidu.com/s/1jZoTmoFzaTLWh4lKBHVbEA 密码: 5suq
代码讲解:主要内容是数据集加载
导入package:
from torch.utils.data import Dataset
from PIL import Image
import os
创建了一个MyData类用于存放数据集,继承Dateset的类:
class MyData(Dataset):
在实际项目时,要构建自己的数据集,需要继承Dateset的类,Pytorch才能读取使用。想要实现这个类,必须要重写3个方法:init(self, 参数…)、 getitem(self, index)、len(self)。
1)进行初始化,加载相应的参数,self代表实例本身
def __init__(self,root_dir,label_dir):
2)两个全局变量,第一个是相对路径,第二个是标签(就是文件夹名字)
self.root_dir=root_dir
self.label_dir=label_dir
3)通过os函数,把目录和文件名合成一个路径
self.path=os.path.join(self.root_dir,self.label_dir)
4)img_path通过os中的listdir方法返回一个包含全部文件名的列表,便于调用等
self.img_path=os.listdir(self.path)
索引 index:
def __getitem__(self, idx):
img_name=self.img_path[idx]
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
img=Image.open(img_item_path)
label=self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
数据集本质应当是所有数据样本的一个列表,因此每个样本都有对应的索引index。我们取用一个样本最简单的方式就是用该样本的index从数据列表中把它取出来。__getitem__就是做这样一件事。
1)一般如果想使用索引访问元素时,就可以在类中定义这个方法(__getitem__(self, idx) )。
def __getitem__(self, idx):
2)读取列表,获得图片名称
img_name=self.img_path[idx]
3)拼接相对地址, 文件夹名称, 图片名称的最终路径
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
4)读取图片
img=Image.open(img_item_path)
5)获取了文件中的标签信息,并将图片和标签返回
label=self.label_dir
return img, label
列表长度:
用于获取列表的长度,返回数字,注意从0开始
def __len__(self):
return len(self.img_path)
生成文档:
root_dir="dataset/train"
ants_label_dir="ants_image"
ants_dataset=MyData(root_dir,ants_label_dir)
target_dir="ants_image"
label=target_dir.split('_')[0]
img_path=os.listdir(os.path.join(root_dir,target_dir))
out_dir="ants_label"
for i in img_path:
file_name=i.split('.jpg')[0]
with open(os.path.join(root_dir,out_dir,"{}.txt".format(file_name)),'w') as f:
f.write(label)
label 标签 target 目标
其中str.split('_')[0]是切片函数,就是将str这个字符串每有一个_切一刀,然后取第【i】段
举个例子:
str="linlinxuezhang_dashuaige_jushiwushaung"
str.split('_')[0] 就是linlinxuezhang
str.split('_')[1] 就是dashuaige
写文件操作,将标签写入root_dir/out_dir并创建txt文件,文件内容为label
'w': 打开一个文件只用于写入。如果该文件已存在则将其覆盖。如果该文件不存在,创建新文件。
整段代码如下:
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(self.root_dir,self.label_dir)
self.img_path=os.listdir(self.path)
def __getitem__(self, idx):
img_name=self.img_path[idx]
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
img=Image.open(img_item_path)
label=self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir="dataset/train"
ants_label_dir="ants_image"
ants_dataset=MyData(root_dir,ants_label_dir)
target_dir="ants_image"
label=target_dir.split('_')[0]
img_path=os.listdir(os.path.join(root_dir,target_dir))
out_dir="ants_label"
for i in img_path:
file_name=i.split('.jpg')[0]
with open(os.path.join(root_dir,out_dir,"{}.txt".format(file_name)),'w') as f:
f.write(label)