【DL笔记】Tiny ImageNet 图像分类数据集的下载与使用

介绍

Tiny ImageNet Challenge 来源于斯坦福 CS231N 课程,共237M

Tiny Imagenet 有 200 个类。 每个类有 500 张训练图像、50 张验证图像和 50 张测试图像。

下载链接:

http://cs231n.stanford.edu/tiny-imagenet-200.zip

数据集使用

因为下载来的train跟val文件夹下图片存放位置不一样,所以路径需要一些变动

wnids.txt存放着标签

words.txt存放标签跟对应的描述,可以在few-shot或是zero-shot的时候用(下面的加载代码没有使用,只是做简单的分类任务)

train/label/xx/xx_boxes.txt与val/val_annotations.txt: 包括lable与boundingbox的标注,目标检测任务中使用(下面的加载代码没有使用,只是做简单的分类任务)

下面附上代码:

from typing import Any
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import glob
import argparse
from PIL import Image

class TrainTinyImageNet(Dataset):
    def __init__(self, root, id, transform=None) -> None:
        super().__init__()
        self.filenames = glob.glob(root + "\\train\*\*\*.JPEG")
        self.transform = transform
        self.id_dict = id
    
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx: Any) -> Any:
        img_path = self.filenames[idx]
        image = Image.open(img_path)
        if image.mode == 'L':
            image = image.convert('RGB')
        label = self.id_dict[img_path.split('\\')[-3]]
        if self.transform:
            image = self.transform(image)
        return image, label

class ValTinyImageNet(Dataset):
    def __init__(self, root, id, transform=None):
        self.filenames = glob.glob(root + "\\val\images\*.JPEG")
        self.transform = transform
        self.id_dict = id
        self.cls_dic = {}
        for i, line in enumerate(open(root + '\\val\\val_annotations.txt', 'r')):
            a = line.split('\t')
            img, cls_id = a[0], a[1]
            self.cls_dic[img] = self.id_dict[cls_id]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        image = Image.open(img_path)
        if image.mode == 'L':
            image = image.convert('RGB')
        label = self.cls_dic[img_path.split('\\')[-1]]
        if self.transform:
            image = self.transform(image)
        return image, label

def load_tinyimagenet(args):
    batch_size = args.batch_size
    nw = args.workers
    root = 'E:\PythonProjects\dataset\\tiny-imagenet-200'
    id_dic = {}
    for i, line in enumerate(open(root+'\wnids.txt','r')):
        id_dic[line.replace('\n', '')] = i
    num_classes = len(id_dic)
    data_transform = {
        "train": transforms.Compose([transforms.Resize(224),
                                     transforms.RandomCrop(224, padding=4),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    train_dataset = TrainTinyImageNet(root, id=id_dic, transform=data_transform["train"])
    val_dataset = ValTinyImageNet(root, id=id_dic, transform=data_transform["val"])

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw)
    
    print("TinyImageNet Loading SUCCESS"+
          "\nlen of train dataset: "+str(len(train_dataset))+
          "\nlen of val dataset: "+str(len(val_dataset)))
    
    return train_loader, val_loader, num_classes


if __name__ == '__main__':
    parser = argparse.ArgumentParser("parameters")
    parser.add_argument("--batch-size", type=int, default=120, help="number of batch size, (default, 512)")
    parser.add_argument('--workers', type=int, default=7)
    
    parser.add_argument('--seed', default=42, type=int, nargs='+',
                    help='seed for initializing training. ')

    args = parser.parse_args()
    train, val, num_classes = load_tinyimagenet(args)

workers是数据预加载的参数,可以根据cpu情况自行更改

  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值