根据数据集的样子,来定义我们的class类
# -*- coding: utf-8 -*-
# @Time : 2022/4/22 10:45
# @Author : 李新宇
# @FileName: read_data.py
# @Software: PyCharm
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
#普通变量不能给另一个函数使用,self相当于指定一个类中的全局变量
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
#获得所有图片的列表
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
#获得单个的图片
img_name = self.img_path[idx]
#现在知道了名称,还需要加上单个图片路径
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
#读取图片
img = Image.open(img_item_path)
label = self.label_dir
return img, label
#数据集到底有多长
def __len__(self):
return len(self.img_path)
#创建实例
#蚂蚁数据集
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
ants_dateset = MyData(root_dir,ants_label_dir)
#想获取蚂蚁数据集的第一个变量
ants_dateset[0]
# Out[5]:
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=768x512 at 0x1FC1C874D60>,'ants_image')
#是因为在 def __getitem__(self, idx)中,定义了给idx,return:img+label
#蜜蜂数据集
bees_label_dir = "bees"
bees_dateset = MyData(root_dir,ants_label_dir)
#将数据集合在一起
train_dataset = ants_dateset + bees_dateset
可通过控制台载入查看各类图片属性! 使用绝对路径是方便移植和部署
os操作,将文件夹整个变为一个列表
通过[idx]索引
若要通过idx去获取图片的话,首先需要去创建图片地址的一个列表
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
def __getitem__(self, idx):
全部定义完成,实例创建成功
# -*- coding: utf-8 -*-
# @Time : 2022/4/22 10:45
# @Author : 李新宇
# @FileName: read_data.py
# @Software: PyCharm
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
#普通变量不能给另一个函数使用,self相当于指定一个类中的全局变量
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir,self.label_dir)
#获得所有图片的列表
self.img_path = os.listdir(self.path)
def __getitem__(self, idx):
#获得单个的图片
img_name = self.img_path[idx]
#现在知道了名称,还需要加上单个图片路径
img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)
#读取图片
img = Image.open(img_item_path)
label = self.label_dir
return img, label
#数据集到底有多长
def __len__(self):
return len(self.img_path)
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
#创建实例
ants_dateset = MyData(root_dir,ants_label_dir)
想获取蚂蚁数据集的第一个变量
ants_dateset[0]
因为在 def getitem(self, idx)中,定义了给idx,return:img+label
得
ants_dateset[0]
Out[5]:
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=768x512 at 0x1FC1C874D60>,
'ants')
将其return赋值给img和label:
img,label = ants_dateset[0]
img.show()
蜜蜂数据集+俩数据集合在一起
#蜜蜂数据集
bees_label_dir = "ants"
bees_dateset = MyData(root_dir,ants_label_dir)
train_dataset = ants_dateset + bees_dateset
len(ants_dateset)
Out[10]: 124
len(bees_dateset)
Out[11]: 124
len(train_dataset)
Out[12]: 248
尝试将图片label放入另一个文件夹,新建文件夹。
不同的数据集格式不同
这里的文件格式,如下
# -*- coding: utf-8 -*-
# @Time : 2022/4/22 12:13
# @Author : 李新宇
# @FileName: rename_dateset.py
# @Software: PyCharm
import os
root_dir =r'C:\Users\Eden\Desktop\tudui_pytorch\hymenoptera_data\train'
target_dir = 'ants_image'
img_path = os.listdir(os.path.join(root_dir,target_dir))
label = target_dir.split('_')[0]
out_dir = 'ants_label'
for i in img_path:
file_name = i.split('.jpg')[0]
with open(os.path.join(root_dir,out_dir,'{}.txt'.format(file_name)),'w' ) as f:
f.write(label)