torchvision.datasets.ImageFolder调用结构:
对于简单的图像分类任务,并不需要自己定义一个 Dataset类,可以直接调用 torchvision.datasets.ImageFolder 返回训练数据与标签。
数据集应满足pytorch的格式要求,即将数据集分割为训练集和测试集,并将数据和标签分别放入不同的文件夹;
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
同时,应兼顾按比例划分训练集,测试集及验证集的需求。
下面的函数,将人眼睁闭数据集转换为pytorch指定的结构;
原始数据集:
调用代码示例:
import os
import shutil
import random
class PictureClassifier(object):
def __init__(self, img_dir, target_dir, categories, train_percent, validate_percent, test_percent):
self.img_dir = img_dir
self.target_dir = target_dir
self.categories = categories
self.train_percent = train_percent
self.validate_percent = validate_percent
self.test_percent = test_percent
for category in categories:
os.makedirs(os.path.join(target_dir, 'train', category))
os.makedirs(os.path.join(target_dir, 'validate', category))
os.makedirs(os.path.join(target_dir, 'test', category))
#定义通过图片名获取标签的方法,返回标签
def getLabelByFileName(self, filename):
pass
#检验被遍历对象,是否为需要处理图片的方法,返回true或false
def isPic(self, filename):
pass
#遍历img_dir下的所有文件,逐一进行操作
def classify(self):
for root, dirs, files in os.walk(self.img_dir):
for file in files:
# 打印所有文件对象路径:
# print(os.path.join(root, file))
# 该file所在的路径
# print(root)
fileName = file
if self.isPic(fileName):
label = self.getLabelByFileName(fileName)
if random.random() < self.train_percent:
shutil.copy(os.path.join(root, file), os.path.join(self.target_dir, 'train', label, file))
elif random.random() < self.validate_percent:
shutil.copy(os.path.join(root, file), os.path.join(self.target_dir, 'validate', label, file))
else:
shutil.copy(os.path.join(root, file), os.path.join(self.target_dir, 'test', label, file))
else:
continue
class MyPictureClassifier(PictureClassifier):
def __init__(self, img_dir, target_dir, categories,train_percent, validate_percent, test_percent):
super(MyPictureClassifier, self).__init__(img_dir, target_dir, categories,train_percent, validate_percent, test_percent)
def getLabelByFileName(self, filename):
#数据集第四个位置为标签名:
num_str = filename.split('_')[4]
if num_str=="0":
return 'close'
else:
return 'open'
def isPic(self, filename):
return filename.endswith('.png')
# 图片所在的文件夹
img_dir = 'D:\mrlEyes_2018_01'
# 将图片转换后存放的文件夹
target_dir = 'D:\eyeDataSet'
# 类别信息
categories = ['open', 'close']
worker=MyPictureClassifier(img_dir,target_dir,categories,0.8,0.1,0.1)
worker.classify()
转换后: