ResNeXt 之 输入数据预处理代码详解

了解深度学习的同学都知道,在写代码的时候,主要的两个部分就是网络搭建和数据预处理工作,那么这就需要我们不断地积累才能更好地使用,还不了解ILSVRC2012数据集形式的要先了解其形式,基本形式,就是 类别文件夹(n0xxxxx)->图片名 ,想要详细了解的,给大家推荐两篇博客,一定要结合来看。

链接:https://blog.csdn.net/u012024357/article/details/90679222

          https://blog.csdn.net/tjuyanming/article/details/91354244

大家了解了数据集格式后,接下来我会给大家介绍ResNeXt的数据预处理工作是怎么进行的,我在代码部分的关键部分都做了详细的注释,大家一定要看代码。

from torchvision import transforms, datasets
import os
import torch
from PIL import Image
import scipy.io as scio

IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm']

def ImageNetData(args):
# data_transform, pay attention that the input of Normalize() is Tensor and the input of RandomResizedCrop() or RandomHorizontalFlip() is PIL Image
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    image_datasets = {}
    #image_datasets['train'] = datasets.ImageFolder(os.path.join(args.data_dir, 'ILSVRC2012_img_train'), data_transforms['train'])
    
    #参数解释:  训练集图片路径,文件夹与类别名的映射文件,设置对图片进行的处理
    image_datasets['train'] = ImageNetTrainDataSet(os.path.join(args.data_dir, 'ILSVRC2012_img_train'),
                                           os.path.join(args.data_dir, 'ILSVRC2012_devkit_t12', 'data', 'meta.mat'),
                                           data_transforms['train'])
    #参数解释:  验证集图片路径,图片与类别的映射文件, 设置对图片进行的处理
    image_datasets['val'] = ImageNetValDataSet(os.path.join(args.data_dir, 'ILSVRC2012_img_val'),
                                               os.path.join(args.data_dir, 'ILSVRC2012_devkit_t12', 'data','ILSVRC2012_validation_ground_truth.txt'),
                                               data_transforms['val'])

    # wrap your data and label into Tensor
    dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                 batch_size=args.batch_size,
                                                 shuffle=True,
                                                 num_workers=args.num_workers) for x in ['train', 'val']}


    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']} #返回一个字典!
    return dataloders, dataset_sizes

class ImageNetTrainDataSet(torch.utils.data.Dataset):
    def __init__(self, root_dir, img_label, data_transforms):
        label_array = scio.loadmat(img_label)['synsets']#读取映射文件中的synsets部分
        label_dic = {}
        for i in  range(1000):
            label_dic[label_array[i][0][1][0]] = i#label_array[i][0][1][0]:图像文件夹编号(相当于读入1000个文件夹),和对应的类别,因为共1000个类别
        self.img_path = os.listdir(root_dir)#遍历训练集的文件夹(类别)数
        self.data_transforms = data_transforms
        self.label_dic = label_dic #文件夹和对应的类别组成的字典
        self.root_dir = root_dir
        self.imgs = self._make_dataset()#这里要用self.label_dict

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, item): #Python的魔法方法__getitem__ 可以让对象实现迭代功能
        data, label = self.imgs[item]
        img = Image.open(data).convert('RGB')
        if self.data_transforms is not None:
            try:
                img = self.data_transforms(img)
            except:
                print("Cannot transform image: {}".format(self.img_path[item]))
        return img, label

    def _make_dataset(self):
        class_to_idx = self.label_dic# 文件夹和类别所对应的的类别
        images = []
        dir = os.path.expanduser(self.root_dir)
        for target in sorted(os.listdir(dir)):#target是每一类图像文件夹的名称
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue

            for root, _, fnames in sorted(os.walk(d)):#fnames 是 该类别文件夹下的所有图片
                for fname in sorted(fnames):
                    if self._is_image_file(fname):
                        path = os.path.join(root, fname)#每一张图片的路径
                        item = (path, class_to_idx[target])#每一张图片的路径和它所对应的类别
                        images.append(item)#加入images

        return images

    def _is_image_file(self, filename):
        """Checks if a file is an image.

        Args:
            filename (string): path to a file

        Returns:
            bool: True if the filename ends with a known image extension
        """
        filename_lower = filename.lower()
        return any(filename_lower.endswith(ext) for ext in IMG_EXTENSIONS)

class ImageNetValDataSet(torch.utils.data.Dataset):
    def __init__(self, img_path, img_label, data_transforms):
        self.data_transforms = data_transforms
        img_names = os.listdir(img_path)#获取验证集中所有图片的名称组成img_names(list类型)
        img_names.sort()#对list类型的数据进行排序
        self.img_path = [os.path.join(img_path, img_name) for img_name in img_names]
        with open(img_label,"r") as input_file:
            lines = input_file.readlines()
            self.img_label = [(int(line)-1) for line in lines] #获取label,[1,val_lengths]

    def __len__(self):
        return len(self.img_path)

    def __getitem__(self, item):          #Python的魔法方法__getitem__ 可以让对象实现迭代功能
        img = Image.open(self.img_path[item]).convert('RGB')
        label = self.img_label[item]
        if self.data_transforms is not None:
            try:
                img = self.data_transforms(img)
            except:
                print("Cannot transform image: {}".format(self.img_path[item]))
        return img, label #返回一个tuple数据类型。

这里大家需要了解的是python中的 __getitem__方法的用法。

另外大家疑惑最多的应该是这部分:

label_array = scio.loadmat(img_label)['synsets']#读取映射文件中的synsets部分,这里保存的最重要的信息就是类别和文件夹的对应关系
        label_dic = {}
        for i in  range(1000):
            label_dic[label_array[i][0][1][0]] = i#label_array[i][0][1][0]:图像文件夹编号(相当于读入1000个文件夹),和对应的类别,因为共1000个类别

 其实,这是由.mat文件中的数据类型所决定的,因为 scio.loadmat(img_label) 读出来的是字典型数据,因此我们需要得到 'synsets' 所对应的内容。为了方便大家理解们这里大家们可以将label_array打印出来,查看他的属性(剧透:尺寸是[1860,1]),对应的代码:

import scipy.io as scio

path = './/ImageNet//ILSVRC2012_devkit_t12//data//meta.mat'

result = scio.loadmat(path)

print(type(result))

for i in range(2000):
    print(i)
    print(result['synsets'][i][0][1][0])#是为了获取图片文件夹编号,典型的对数据做切片

到这里,应该就没有问题了。

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值