数据集加载
创建一个自定义的PyTorch数据集类 MyData
,以及使用这个数据集类来加载图像和标签,并生成一个标签文件。
导入必要的库:
用于创建自定义数据集
用于用于图像处理
用于操作系统功能,如文件路径操作
from torch.utils.data import Dataset
from PIL import Image
import os
创建了一个MyData类用于存放数据集,继承Dateset的类:
class MyData(Dataset):
1>初始化数据集类的属性,加载相应的参数
def __init__(self, root_dir, label_dir):
2>两个全局变量,第一个是相对路径,第二个是标签
self.root_dir=root_dir
self.label_dir=label_dir
3>构建图像文件的完整路径
self.path = os.path.join(self.root_dir, self.label_dir)
4>获取目录下所有图像文件的列表
self.img_path = os.listdir(self.path)
索引 index:
在PyTorch中,__getitem__
和 __len__
是数据集类(继承自 torch.utils.data.Dataset
)中必须实现的两个特殊方法
def __getitem__(self, idx):
img_name=self.img_path[idx]
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
img=Image.open(img_item_path)
label=self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
1>用于按索引 idx
访问数据集中的单个样本。它应该返回一个数据样本(通常是图像和标签)
def __getitem__(self, idx):
2>从 self.img_path
(存储图像文件名的列表)中获取索引 idx
对应的图像文件名
img_name=self.img_path[idx]
3>使用 os.path.join
函数来构建完整的图像文件路径。它将 self.root_dir
(数据集的根目录)、self.label_dir
(标签目录)和 img_name
(图像文件名)连接起来
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
4>使用 PIL.Image.open
函数打开构建的图像路径,并返回一个 PIL.Image
对象
img=Image.open(img_item_path)
5>获取了文件中的标签信息,并将图片和标签返回
label=self.label_dir
return img, label
列表长度:
def __len__(self):
return len(self.img_path)
__len__
方法用于返回数据集中样本的数量,从0开始
生成文档:
root_dir="dataset/train"
ants_label_dir="ants_image"
ants_dataset=MyData(root_dir,ants_label_dir)
target_dir="ants_image"
label=target_dir.split('_')[0]
img_path=os.listdir(os.path.join(root_dir,target_dir))
out_dir="ants_label"
for i in img_path:
file_name=i.split('.jpg')[0]
with open(os.path.join(root_dir,out_dir,"{}.txt".format(file_name)),'w') as f:
f.write(label)
1>设置根目录和标签目录
root_dir = "dataset/train"
ants_label_dir = "ants_image"
这里定义了数据集的根目录 root_dir
和标签目录 ants_label_dir
2>创建数据集实例
ants_dataset = MyData(root_dir, ants_label_dir)
使用 root_dir
和 ants_label_dir
作为参数创建 MyData
类的实例 ants_dataset
3>设置目标目录和标签
target_dir = "ants_image"
label = target_dir.split('_')[0]
target_dir
是存放图像的目录,这里和 ants_label_dir
相同。label
是通过分割 target_dir
字符串并取第一部分得到的,这假设目录名是由多个单词组成,标签是第一个单词
4>获取图像文件列表
img_path = os.listdir(os.path.join(root_dir, target_dir))
使用 os.listdir
函数列出 root_dir
下 target_dir
目录中的所有文件
5>创建标签文件
out_dir = "ants_label"
for i in img_path:
file_name = i.split('.jpg')[0]
with open(os.path.join(root_dir, out_dir, "{}.txt".format(file_name)), 'w') as f:
f.write(label)
对于 target_dir
目录中的每个图像文件:
- 使用
split('.jpg')[0]
从文件名中移除.jpg
扩展名,获取基本的文件名。 - 打开或创建一个位于
root_dir
下out_dir
目录中的文本文件,文件名与图像文件名相同(不包含扩展名),扩展名为.txt
。 - 写入之前从
target_dir
获取的label
。
整段代码:
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self,root_dir,label_dir):
self.root_dir=root_dir
self.label_dir=label_dir
self.path=os.path.join(self.root_dir,self.label_dir)
self.img_path=os.listdir(self.path)
def __getitem__(self, idx):
img_name=self.img_path[idx]
img_item_path=os.path.join(self.root_dir,self.label_dir,img_name)
img=Image.open(img_item_path)
label=self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
root_dir="dataset/train"
ants_label_dir="ants_image"
ants_dataset=MyData(root_dir,ants_label_dir)
target_dir="ants_image"
label=target_dir.split('_')[0]
img_path=os.listdir(os.path.join(root_dir,target_dir))
out_dir="ants_label"
for i in img_path:
file_name=i.split('.jpg')[0]
with open(os.path.join(root_dir,out_dir,"{}.txt".format(file_name)),'w') as f:
f.write(label)