Pytorch自定义数据集 ImageNet

由于Pytorch不支持内置的ImageNet数据集,因此我们需要自定义数据集。有两种方式

1、使用ImageFolder

ImageFolder需要数据集有良好的结构,train和test下分别包含相同类别的文件夹,每个文件夹存放一类图像,也就是这样

——ImageNet
  ——train
  	——cls1
  		——cls1_00.jpg
  		——cls1_01.jpg
  		...
  		——cls1_59.jpg
  	——cls2
  	——clsn
  ——test
  	——cls1
  		——cls1_60.jpg
  		——cls1_61.jpg
  		...
  		——cls1_100.jpg
  	——cls2
  	——clsn

此时把.../ImageNet/train 或者 .../ImageNet/test/ 当作imagenet_root传入ImageFolder即可

from torchvision.datasets import ImageFolder
imagenet_train = ImageFolder(imagenet_root, transform=transform_imagenet_train)
train_iter= DataLoader(imagenet_train, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                               pin_memory=True)

2、使用自定义的数据集

当我们使用Dataset 时,必须定义__init____getitem__以及__len__三个成员函数。

  • __init__ :初始化,进行数据集的准备工作
  • __len__ :返回数据集的大小
  • __getitem__:根据索引(必要的参数)从数据集中提取数据。索引的大小为[0,self.__len__())

下面是从ImageNet全集中随机挑选100类(当然也可以任意指定数量)的代码,可以保存每一次都选取了哪些类用作训练。使用子集训练的原因是进行有效性实验,以及计算资源的限制。

# ImageNet100.py
from torch.utils.data import Dataset
import numpy as np
import torchvision.transforms as T
import os
import random
from PIL import Image


class ImageNet100(Dataset):
    def __init__(self, root_path="../Dataset/ImageNet", shuffle=True, train=False, new_data=False, transform = T.ToTensor()):
        super().__init__()
        self.length = 0
        self.transform = transform
        
        if train:
            path = os.path.join(root_path, "train")
        else:
            path = os.path.join(root_path, "val")
        cls_list = sorted(os.listdir(path))
        
        cls100_index_file = os.path.join(root_path, "cls_index.npy")  # 存放100个子类的索引
        if (not os.path.exists(cls100_index_file)) or new_data:
            index_all = np.arange(len(cls_list))
            index = np.random.choice(index_all, 100, replace=False)
            np.save(cls100_index_file, index)
        else:
            print("file exists, loading")
            index = np.load(cls100_index_file)
        
        self.img_label_list = []
        for label, idx in enumerate(index):
            sub_cls = os.path.join(path, cls_list[idx])
            for img in os.listdir(sub_cls):
                self.img_label_list.append((os.path.join(sub_cls, img), label))
                self.length +=1

        if shuffle:
            random.shuffle(self.img_label_list)

        print("dateset size: ", self.length)
        
    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        img_path, label = self.img_label_list[idx]
        with open(img_path, 'rb') as f:
            img = Image.open(f)
            data = self.transform(img.convert('RGB'))
        # data = self.transform(Image.open(img_path)) #数据集不干净,有灰度图时报错
        return data, label

ImageNet数据集报错 output with shape [1, 224, 224] doesn’t match the broadcast shape [3, 224, 224]…

出现这个的错误的原因主要是数据集不够干净,数据集中同时含有RGB图像和灰度图像,可能是下载的地址不正确。这种时候__getitem__中注释的部分就不可行了。
同样的数据集用ImageFolder就可以正常运行,于是我阅读了ImageFolder的源码,找到了读取图像的部分

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

此时,一切问题都解决了。

  • 4
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值