由于Pytorch不支持内置的ImageNet数据集,因此我们需要自定义数据集。有两种方式
1、使用ImageFolder
ImageFolder需要数据集有良好的结构,train和test下分别包含相同类别的文件夹,每个文件夹存放一类图像,也就是这样
——ImageNet
——train
——cls1
——cls1_00.jpg
——cls1_01.jpg
...
——cls1_59.jpg
——cls2
——clsn
——test
——cls1
——cls1_60.jpg
——cls1_61.jpg
...
——cls1_100.jpg
——cls2
——clsn
此时把.../ImageNet/train
或者 .../ImageNet/test/
当作imagenet_root
传入ImageFolder
即可
from torchvision.datasets import ImageFolder
imagenet_train = ImageFolder(imagenet_root, transform=transform_imagenet_train)
train_iter= DataLoader(imagenet_train, batch_size=batch_size, shuffle=True, num_workers=num_workers,
pin_memory=True)
2、使用自定义的数据集
当我们使用Dataset
时,必须定义__init__
、__getitem__
以及__len__
三个成员函数。
__init__
:初始化,进行数据集的准备工作__len__
:返回数据集的大小__getitem__
:根据索引(必要的参数)从数据集中提取数据。索引的大小为[0,self.__len__())
下面是从ImageNet全集中随机挑选100类(当然也可以任意指定数量)的代码,可以保存每一次都选取了哪些类用作训练。使用子集训练的原因是进行有效性实验,以及计算资源的限制。
# ImageNet100.py
from torch.utils.data import Dataset
import numpy as np
import torchvision.transforms as T
import os
import random
from PIL import Image
class ImageNet100(Dataset):
def __init__(self, root_path="../Dataset/ImageNet", shuffle=True, train=False, new_data=False, transform = T.ToTensor()):
super().__init__()
self.length = 0
self.transform = transform
if train:
path = os.path.join(root_path, "train")
else:
path = os.path.join(root_path, "val")
cls_list = sorted(os.listdir(path))
cls100_index_file = os.path.join(root_path, "cls_index.npy") # 存放100个子类的索引
if (not os.path.exists(cls100_index_file)) or new_data:
index_all = np.arange(len(cls_list))
index = np.random.choice(index_all, 100, replace=False)
np.save(cls100_index_file, index)
else:
print("file exists, loading")
index = np.load(cls100_index_file)
self.img_label_list = []
for label, idx in enumerate(index):
sub_cls = os.path.join(path, cls_list[idx])
for img in os.listdir(sub_cls):
self.img_label_list.append((os.path.join(sub_cls, img), label))
self.length +=1
if shuffle:
random.shuffle(self.img_label_list)
print("dateset size: ", self.length)
def __len__(self):
return self.length
def __getitem__(self, idx):
img_path, label = self.img_label_list[idx]
with open(img_path, 'rb') as f:
img = Image.open(f)
data = self.transform(img.convert('RGB'))
# data = self.transform(Image.open(img_path)) #数据集不干净,有灰度图时报错
return data, label
ImageNet数据集报错 output with shape [1, 224, 224] doesn’t match the broadcast shape [3, 224, 224]…
出现这个的错误的原因主要是数据集不够干净,数据集中同时含有RGB图像和灰度图像,可能是下载的地址不正确。这种时候__getitem__
中注释的部分就不可行了。
同样的数据集用ImageFolder
就可以正常运行,于是我阅读了ImageFolder
的源码,找到了读取图像的部分
def pil_loader(path):
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
with open(path, 'rb') as f:
img = Image.open(f)
return img.convert('RGB')
此时,一切问题都解决了。