ImageFolder类上添加划分数据集和打乱数据的功能

问题描述

ImageFolder是一个非常有用的类,只要数据集按照要求规范文件,就可以很轻松的,得到 文件路径和类型 的元祖。同时加载DataLoader 也非常方便,但是在实际用的时候发现缺少了划分数据集的功能,并且是按照数据也是按照文件夹依次得到的,这对划分数据集上非常不利的。

解决

通过阅读ImageFolder源码并在其基础上继承并添加自己的功能。

import time
import torch
import visdom
import torchvision
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader  # DataLoader可以实现一个加载一个batch的功能
import random

def denormalize(x_hat):
    """
    将normalize后的图片返回原来正常的图片
    :param x_hat:
    :return:
    """
    # x_hat[channel] = (x[channel] - mean[channel]) / std[channel]
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    # x:[c,h,w] mean:[3]->[3,1,1]
    mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
    std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
    x = x_hat * std + mean
    return x

class Pokemon(ImageFolder):
    def __init__(self,root,model):
        super(Pokemon, self).__init__(root)
        # 直接类内定义transforms的compose变换
        tf = transforms.Compose([
            # 自定义函数
            transforms.Resize((280, 280)),
            transforms.RandomRotation(15),
            # 可能会有边缘没被看到,而出现黑边, 需要进行中心裁剪
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # 归一化,数值是有imagenet统计得出的,更有普遍性
            # 输出不再是0-1之间分布,而是在-1 - 1之间分布
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.transform = tf
        # 打乱数据
        random.seed(1234)
        random.shuffle(self.imgs)
        random.seed(1234)
        random.shuffle(self.samples)
        if model == "train":  # 60%
            self.imgs = self.imgs[:int(0.6*len(self.imgs))]
            self.samples = self.samples[:int(0.6*len(self.samples))]
        elif model == "val":  # 20%
            self.imgs = self.imgs[int(0.6 * len(self.imgs)):int(0.8 * len(self.imgs))]
            self.samples = self.samples[int(0.6 * len(self.imgs)):int(0.8 * len(self.imgs))]
        else:  # 20% = 80% -> 100%
            self.imgs = self.imgs[int(0.8 * len(self.imgs)):]
            self.samples = self.samples[int(0.8 * len(self.imgs)):]

if __name__ == '__main__':
    viz = visdom.Visdom()
    # ImageFolder会自动完成文件夹的编码
    db = Pokemon(root="pokemon", model='train')
    print(db.class_to_idx)
    # bath进行加载
    loader = DataLoader(db, batch_size=32)
    for x, y in loader:  # 这里的x,y是一个batch的
        viz.images(denormalize(x), nrow=8, win="batch", opts=dict(title="batch"))
        print(y)
        viz.text(str(y.tolist()), win="label", opts=dict(title="batch-label"))
        time.sleep(10)

输出结果

Setting up a new session...
{'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
tensor([4, 3, 1, 0, 1, 3, 2, 3, 1, 0, 1, 3, 2, 0, 0, 4, 0, 0, 3, 0, 0, 0, 1, 4,
        0, 2, 3, 2, 2, 2, 4, 3])

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值