ImageNet is a widely used dataset for image classification. MiniImageNet is a scaled-down version of ImageNet on which new methods can be quickly prototyped and tested.
1. Download links
ImageNet (about 100 GB) can be downloaded from the official website; MiniImageNet is about 3 GB, download link, password: hl31.
2. The original ImageNet dataset
To use the original ImageNet dataset, the torchvision.datasets.ImageFolder class (a subclass of DatasetFolder) is all we need. It is used as follows:
import os

from torch.utils import data
from torchvision import datasets, transforms


def load_dataset(args):
    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.RandomResizedCrop(args.size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.Resize(args.size),
        transforms.CenterCrop(args.size),
        transforms.ToTensor(),
        normalize,
    ])
    # ImageFolder builds (image, class_index) samples from the class-named subfolders
    train_dataset = datasets.ImageFolder(traindir, transform=train_transform)
    val_dataset = datasets.ImageFolder(valdir, transform=val_transform)
    # Use a DistributedSampler when training with multiple processes
    if args.distributed:
        train_sampler = data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_loader = data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
    val_loader = data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, drop_last=True)
    return train_loader, val_loader, train_sampler
Then, before training and testing the network, simply call load_dataset:
train_loader, val_loader, train_sampler = load_dataset(args)
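A minimal sketch of consuming these loaders is shown below; the model, criterion, optimizer, and args.epochs are placeholders that are not defined in this post, and torch is assumed to be imported:
for epoch in range(args.epochs):
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)  # reshuffle shards across processes each epoch
    model.train()
    for images, targets in train_loader:
        output = model(images)
        loss = criterion(output, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        for images, targets in val_loader:
            output = model(images)  # compute validation metrics here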
3. Transforming the ImageNet dataset
However, in some cases we may want to design the data format of train_loader and val_loader according to our own requirements. How can we do that?
We can create a new class; for example, I named mine TransformedImageFolder. It inherits from torchvision.datasets.DatasetFolder, initializes the parent class in __init__, and adds its own arguments.
The functionality implemented here is: for each image, return the original image rotated by [0, 90, 180, 270] degrees, stacked together. In other words, the dataset contains both the original data and its transformed versions, and is effectively four times the size of the original dataset. This kind of rotation transform is frequently needed in contrastive learning and related self-supervised settings.
The implementation is as follows:
import numpy as np
import torch
from PIL import Image
from torchvision import datasets
from torchvision.datasets.folder import IMG_EXTENSIONS, default_loader


class TransformedImageFolder(datasets.DatasetFolder):
    def __init__(self, root, train=False, transform=None, target_transform=None,
                 loader=default_loader, is_valid_file=None):
        super(TransformedImageFolder, self).__init__(root, loader,
                                                     IMG_EXTENSIONS if is_valid_file is None else None,
                                                     transform=transform,
                                                     target_transform=target_transform,
                                                     is_valid_file=is_valid_file)
        self.imgs = self.samples
        self.train = train

    def __getitem__(self, index):
        path, target = self.samples[index]
        img = self.loader(path)
        img = np.array(img)
        # During training, apply a random horizontal flip before rotating
        if self.train:
            if np.random.rand() < 0.5:
                img = img[:, ::-1, :]
        # Rotate the image by 0, 90, 180 and 270 degrees
        img0 = np.rot90(img, 0).copy()
        img1 = np.rot90(img, 1).copy()
        img2 = np.rot90(img, 2).copy()
        img3 = np.rot90(img, 3).copy()
        img0 = Image.fromarray(img0)
        img1 = Image.fromarray(img1)
        img2 = Image.fromarray(img2)
        img3 = Image.fromarray(img3)
        img0 = self.transform(img0)
        img1 = self.transform(img1)
        img2 = self.transform(img2)
        img3 = self.transform(img3)
        # Stack the four views into a single tensor of shape (4, C, H, W)
        img = torch.stack([img0, img1, img2, img3])
        return img, target
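A hypothetical usage sketch follows, reusing traindir, train_transform, and args from the load_dataset code above. With batch size B, every batch from this dataset has shape (B, 4, C, H, W); a common trick is to flatten the four views into the batch dimension and build matching rotation labels:
train_dataset = TransformedImageFolder(traindir, train=True, transform=train_transform)
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                               num_workers=args.workers, pin_memory=True, drop_last=True)

for imgs, targets in train_loader:
    b, r, c, h, w = imgs.shape                 # (B, 4, C, H, W)
    imgs = imgs.view(b * r, c, h, w)           # flatten the 4 rotated views into the batch
    rot_labels = torch.arange(r).repeat(b)     # 0,1,2,3,0,1,2,3,... rotation targets
    cls_labels = targets.repeat_interleave(r)  # original class label for every view
    # feed imgs to the model; rot_labels / cls_labels serve as the supervision signals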
4. The original MiniImageNet dataset
The MiniImageNet dataset is essentially a subset of images drawn from ImageNet; it can be defined along the following lines:
import json
import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset


class MiniImagenet(Dataset):
    """Mini-ImageNet dataset."""
    def __init__(self,
                 root_dir: str,
                 csv_name: str,
                 json_path: str,
                 transform=None):
        images_dir = os.path.join(root_dir, "images")
        assert os.path.exists(images_dir), "dir:'{}' not found.".format(images_dir)
        assert os.path.exists(json_path), "file:'{}' not found.".format(json_path)
        # Maps each label string to a list whose first element is the class index
        self.label_dict = json.load(open(json_path, "r"))
        csv_path = os.path.join(root_dir, csv_name)
        assert os.path.exists(csv_path), "file:'{}' not found.".format(csv_path)
        csv_data = pd.read_csv(csv_path)
        self.total_num = csv_data.shape[0]
        self.img_paths = [os.path.join(images_dir, i) for i in csv_data["filename"].values]
        self.img_label = [self.label_dict[i][0] for i in csv_data["label"].values]
        self.labels = set(csv_data["label"].values)
        self.transform = transform

    def __len__(self):
        return self.total_num

    def __getitem__(self, item):
        img = Image.open(self.img_paths[item])
        # "RGB" means a color image, "L" means a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.img_paths[item]))
        label = self.img_label[item]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
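As a quick sanity check (the root directory name below is a placeholder; the csv and json names match the ones used in load_minidataset further down), the on-disk layout this class expects and a minimal usage look roughly like this:
# Expected layout, based on what __init__ reads (the root dir name is an assumption):
#   mini-imagenet/
#   ├── images/          # all image files referenced by the csv files
#   ├── new_train.csv    # columns: "filename", "label"
#   └── new_val.csv
#   utils/classes_name.json  # maps each label string to a list whose first element is the class index
dataset = MiniImagenet(root_dir="mini-imagenet",
                       csv_name="new_train.csv",
                       json_path="utils/classes_name.json",
                       transform=transforms.ToTensor())
img, label = dataset[0]
print(img.shape, label)  # e.g. torch.Size([3, H, W]) and an integer class index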
The dataset is loaded much like ImageNet, with slight differences because the constructor takes different arguments:
def load_minidataset(args):
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    data_root = args.data
    json_path = "utils/classes_name.json"
    train_dataset = MiniImagenet(root_dir=data_root,
                                 csv_name="new_train.csv",
                                 json_path=json_path,
                                 transform=data_transform["train"])
    val_dataset = MiniImagenet(root_dir=data_root,
                               csv_name="new_val.csv",
                               json_path=json_path,
                               transform=data_transform["val"])
    # check num_classes
    if args.num_classes != len(train_dataset.labels):
        raise ValueError("dataset has {} classes, but num_classes is {}".format(len(train_dataset.labels), args.num_classes))
    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    # Create the distributed sampler before the loaders so it is actually passed to the train loader
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=(train_sampler is None),
                                               sampler=train_sampler,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)
    return train_loader, val_loader, train_sampler
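For reference, a sketch of the argparse fields that load_minidataset reads (data, num_classes, batch_size, distributed); the default values are examples only, not part of the original code:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--data", default="mini-imagenet", help="root dir with images/ and the csv files")
parser.add_argument("--num-classes", type=int, default=100)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--distributed", action="store_true")
args = parser.parse_args()

train_loader, val_loader, train_sampler = load_minidataset(args)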
5. Transforming the MiniImageNet dataset
Likewise, if we need to design the data format of train_loader and val_loader according to our own requirements, we can create a new class. For example, I named mine TransMiniImagenet; it inherits from the MiniImagenet class above. The functionality implemented here is again that, for each image, the result is the original image rotated by [0, 90, 180, 270] degrees, stacked together.
The implementation is as follows:
class TransMiniImagenet(MiniImagenet):
    """Transformed Mini-ImageNet: returns four rotated views of every image."""
    def __init__(self, root_dir: str, csv_name: str, json_path: str, transform=None, train=False):
        super(TransMiniImagenet, self).__init__(root_dir, csv_name, json_path, transform)
        self.train = train

    def __getitem__(self, item):
        img = Image.open(self.img_paths[item])
        # "RGB" means a color image, "L" means a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.img_paths[item]))
        label = self.img_label[item]
        img = np.array(img)
        # During training, apply a random horizontal flip before rotating
        if self.train:
            if np.random.rand() < 0.5:
                img = img[:, ::-1, :]
        # Rotate the image by 0, 90, 180 and 270 degrees
        img0 = np.rot90(img, 0).copy()
        img1 = np.rot90(img, 1).copy()
        img2 = np.rot90(img, 2).copy()
        img3 = np.rot90(img, 3).copy()
        img0 = Image.fromarray(img0)
        img1 = Image.fromarray(img1)
        img2 = Image.fromarray(img2)
        img3 = Image.fromarray(img3)
        if self.transform is not None:
            img0 = self.transform(img0)
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        # Stack the four views into a single tensor of shape (4, C, H, W)
        img = torch.stack([img0, img1, img2, img3])
        return img, label
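To plug it in, one option (a sketch of one possible change, not part of the original code) is to swap the dataset construction inside load_minidataset; every batch then has shape (B, 4, C, H, W), exactly as in section 3:
# Inside load_minidataset, replace the MiniImagenet instances with TransMiniImagenet;
# data_root, json_path and data_transform are the variables already defined there.
train_dataset = TransMiniImagenet(root_dir=data_root,
                                  csv_name="new_train.csv",
                                  json_path=json_path,
                                  transform=data_transform["train"],
                                  train=True)
val_dataset = TransMiniImagenet(root_dir=data_root,
                                csv_name="new_val.csv",
                                json_path=json_path,
                                transform=data_transform["val"],
                                train=False)
# The rest of the function stays unchanged: collate_fn stacks the (4, C, H, W) samples,
# so each batch has shape (B, 4, C, H, W) and can be flattened as shown in section 3.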