【pytorch】图片分类问题处理一般数据集,使其满足torchvision.datasets.ImageFolder调用结构

65 篇文章 3 订阅
44 篇文章 4 订阅

torchvision.datasets.ImageFolder调用结构:

对于简单的图像分类任务,并不需要自己定义一个 Dataset类,可以直接调用 torchvision.datasets.ImageFolder 返回训练数据与标签。

数据集应满足pytorch的格式要求,即将数据集分割为训练集和测试集,并将数据和标签分别放入不同的文件夹;

root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png

root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png

同时,应兼顾按比例划分训练集,测试集及验证集的需求。
下面的函数,将人眼睁闭数据集转换为pytorch指定的结构;
原始数据集:
在这里插入图片描述

调用代码示例:

import os
import shutil
import random
class PictureClassifier(object):
    def __init__(self, img_dir, target_dir, categories, train_percent, validate_percent, test_percent):
        self.img_dir = img_dir
        self.target_dir = target_dir
        self.categories = categories
        self.train_percent = train_percent
        self.validate_percent = validate_percent
        self.test_percent = test_percent
        for category in categories:
            os.makedirs(os.path.join(target_dir, 'train', category))
            os.makedirs(os.path.join(target_dir, 'validate', category))
            os.makedirs(os.path.join(target_dir, 'test', category))
    #定义通过图片名获取标签的方法,返回标签
    def getLabelByFileName(self, filename):
        pass
    #检验被遍历对象,是否为需要处理图片的方法,返回truefalse
    def isPic(self, filename):
        pass
    #遍历img_dir下的所有文件,逐一进行操作
    def classify(self):
        for root, dirs, files in os.walk(self.img_dir):
            for file in files:
                # 打印所有文件对象路径:
                # print(os.path.join(root, file))
                # 该file所在的路径
                # print(root)
                fileName = file
                if self.isPic(fileName):
                    label = self.getLabelByFileName(fileName)
                    if random.random() < self.train_percent:
                        shutil.copy(os.path.join(root, file), os.path.join(self.target_dir, 'train', label, file))
                    elif random.random() < self.validate_percent:
                        shutil.copy(os.path.join(root, file), os.path.join(self.target_dir, 'validate', label, file))
                    else:
                        shutil.copy(os.path.join(root, file), os.path.join(self.target_dir, 'test', label, file))
                else:
                    continue


class MyPictureClassifier(PictureClassifier):
    def __init__(self, img_dir, target_dir, categories,train_percent, validate_percent, test_percent):
        super(MyPictureClassifier, self).__init__(img_dir, target_dir, categories,train_percent, validate_percent, test_percent)
    def getLabelByFileName(self, filename):
        #数据集第四个位置为标签名:
        num_str = filename.split('_')[4]
        if num_str=="0":
            return 'close'
        else:
            return 'open'


    def isPic(self, filename):
        return filename.endswith('.png')

# 图片所在的文件夹
img_dir = 'D:\mrlEyes_2018_01'

# 将图片转换后存放的文件夹
target_dir = 'D:\eyeDataSet'

# 类别信息
categories = ['open', 'close']

worker=MyPictureClassifier(img_dir,target_dir,categories,0.8,0.1,0.1)
worker.classify()

转换后:
在这里插入图片描述
在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

颢师傅

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值