PyTorch 数据集的读取

知乎传送门

class Mnist_data(Dataset):
    """Dataset over a directory tree of digit images stored as PNG files.

    Images are looked up as ``root_dir/<digit>/*.png`` for the digits 0-9;
    the digit of the sub-directory is used as the integer label.

    Args:
        root_dir: Directory that contains the per-digit sub-directories.
        pre_load: When true, every image is read once in the constructor and
            kept in memory, so ``__getitem__`` does no disk access.
        transform: Optional callable applied to each image before it is
            returned from ``__getitem__``.
    """

    def __init__(self,
                 root_dir,
                 pre_load=False,
                 transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.filenames = []  # (file path, label) pairs
        self.images = None   # in-memory image cache, filled by _pre_load()
        self.labels = None   # labels matching self.images, filled by _pre_load()

        for j in range(10):
            # glob returns matches in arbitrary order; sorting keeps the
            # index -> sample mapping identical across runs and machines.
            all_files = sorted(glob.glob(opj(self.root_dir, str(j), "*.png")))
            self.filenames.extend((file, j) for file in all_files)
        self.len = len(self.filenames)
        if pre_load:
            self._pre_load()

    @staticmethod
    def _read_image(path):
        """Return an in-memory copy of the image at *path* and close the file."""
        with Image.open(path) as img:
            return img.copy()

    def _pre_load(self):
        """Read every image into memory so later lookups skip the disk."""
        self.images = []
        self.labels = []

        for file, label in self.filenames:
            self.images.append(self._read_image(file))
            self.labels.append(label)

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index*, applying ``transform`` if set."""
        if self.images is not None:
            # The cache is populated: serve from memory.
            img = self.images[index]
            label = self.labels[index]
        else:
            img_path, label = self.filenames[index]
            # Copy and close instead of returning an open file handle, the
            # same way _pre_load() does.
            img = self._read_image(img_path)

        if self.transform is not None:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Return the number of images found under ``root_dir``."""
        return self.len

以上代码参考自 cs231n 2018 版的 PyTorch 教程(该教程已更新到 PyTorch 0.4 版本),是比较标准的数据集读取方法。

cs231n有空会上传代码,assignment2做完了结果代码误删了,准备下载2018版的作业重来一遍

 

下面是另一个版本的读取方式。对比之下,上面的官方教程多了一步缓存数据,因此读取更快;而下面这个版本增加了数据集划分(训练/验证/测试)的功能,两者可以综合使用。

class DogCat(data.Dataset):
    """Dog/cat image dataset that also splits the data into train/val/test.

    File names are expected to look like ``cat.10004.jpg`` (train) or
    ``8973.jpg`` (test): the number before the extension is the image id.

    Args:
        root: Directory that directly contains the image files.
        transforms: Optional callable applied to every opened image. When
            omitted, a default pipeline is built (see ``__init__``).
        train: Ignored when ``test`` is true. Otherwise selects the first
            70% of the id-sorted files when true and the remaining 30%
            (validation) when false.
        test: When true, every file under ``root`` is used and the label
            returned by ``__getitem__`` is the image id.
    """

    def __init__(self, root, transforms=None, train=True, test=False):
        '''
        Collect all image paths and split them into train / val / test.
        '''
        self.test = test
        # os.listdir order is arbitrary. Sorting the names first makes the
        # stable sort by id below fully deterministic (cat.N and dog.N share
        # the same id), so the train/val split never changes between runs.
        imgs = [os.path.join(root, img) for img in sorted(os.listdir(root))]

        # test1: data/test1/8973.jpg
        # train: data/train/cat.10004.jpg
        imgs = sorted(imgs, key=self._image_id)

        imgs_num = len(imgs)

        if self.test:
            self.imgs = imgs
        elif train:
            self.imgs = imgs[:int(0.7 * imgs_num)]
        else:
            self.imgs = imgs[int(0.7 * imgs_num):]

        if transforms is not None:
            # Keep the caller's pipeline; without this assignment
            # __getitem__ fails with AttributeError.
            self.transforms = transforms
        else:
            normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

            if self.test or not train:
                # Evaluation: deterministic resize + centre crop.
                self.transforms = T.Compose([
                    T.Resize(224),
                    T.CenterCrop(224),
                    T.ToTensor(),
                    normalize
                    ])
            else:
                # Training: random crop and horizontal flip as augmentation.
                self.transforms = T.Compose([
                    T.Resize(256),
                    T.RandomResizedCrop(224),
                    T.RandomHorizontalFlip(),
                    T.ToTensor(),
                    normalize
                    ])

    @staticmethod
    def _image_id(path):
        """Return the integer id that precedes the extension in the file name."""
        # basename() instead of split('/') so Windows separators work too.
        return int(os.path.basename(path).split('.')[-2])

    def __getitem__(self, index):
        '''
        Return the transformed image at ``index`` together with its label.

        The label is the image id in test mode; otherwise it is 1 when the
        file name contains 'dog' and 0 when it does not.
        '''
        img_path = self.imgs[index]
        if self.test:
            label = self._image_id(img_path)
        else:
            label = 1 if 'dog' in os.path.basename(img_path) else 0
        img = Image.open(img_path)
        img = self.transforms(img)
        return img, label

    def __len__(self):
        """Return the number of images in the selected split."""
        return len(self.imgs)

 

此外,读取图片进行分类最快捷的方式是使用 torchvision 的 ImageFolder:

import torch
from torchvision import transforms, datasets

# Augmentation pipeline: random 224-pixel crop and horizontal flip, then
# conversion to a tensor and per-channel mean/std normalisation.
# RandomResizedCrop is the current name of the removed RandomSizedCrop alias.
data_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
# ImageFolder reads the images found under the given root directory.
hymenoptera_dataset = datasets.ImageFolder(root='hymenoptera_data/train',
                                           transform=data_transform)
# Shuffled mini-batches of 4 images, loaded by 4 worker processes.
dataset_loader = torch.utils.data.DataLoader(hymenoptera_dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=4)

 

下面的是可视化的代码

# Both splits are read into memory up front (pre_load=True) and every image
# is passed through ToTensor() when it is fetched.
trainset = Mnist_data(root_dir='mnist_png/training',
                      transform=transforms.ToTensor(),
                      pre_load=True)
testset = Mnist_data(root_dir='mnist_png/testing',
                     transform=transforms.ToTensor(),
                     pre_load=True)

# Batches of 64 with 4 workers each; only the training loader shuffles.
train_loader = DataLoader(trainset, num_workers=4, batch_size=64, shuffle=True)
test_loader = DataLoader(testset, num_workers=4, batch_size=64, shuffle=False)


from torchvision.utils import make_grid
# functions to show an image
def imshow(img):
    """Draw *img* with matplotlib.

    *img* must provide ``.numpy()``; axis 0 of the resulting array is moved
    to the end before plotting (presumably channels-first input such as the
    output of ``make_grid`` -- confirm against the caller).
    """
    pixels = img.numpy()
    channels_last = np.transpose(pixels, (1, 2, 0))
    plt.imshow(channels_last)

# get one batch of (shuffled) training images
dataiter = iter(train_loader)
# Use the built-in next(): iterators have no .next() method in Python 3.
images, labels = next(dataiter)

# show the whole batch as a single image grid
imshow(make_grid(images))

 

 

另一个比较高效的数据集处理分割方式来自fastai论坛,只要稍微改一下基本就能应用于所有的数据集的清理.

import os
import random
import shutil

def organize_folder(folder):
    """Move the files directly inside *folder* into one sub-folder per class.

    The class of a file is the part of its name before the first ``"."``
    (``cat.12.jpg`` -> ``cat``), and the file is moved to
    ``folder/<class>/<filename>``. Files without a dot, or whose name starts
    with a dot, have no usable class name and are left where they are.

    Args:
        folder: Directory whose top-level files should be sorted into
            per-class sub-directories. Sub-directories are not descended into.
    """
    _, _, filenames = next(os.walk(folder))
    for filename in filenames:
        # Compare the exact prefix rather than using startswith(), so that
        # "cats.1.jpg" is never filed under a shorter class such as "cat".
        _class = filename.split(".")[0]
        if not _class or _class == filename:
            # An empty prefix would target the folder itself; a dot-less name
            # would need a directory with the same path as the file.
            continue
        path = os.path.join(folder, _class)
        if not os.path.exists(path):
            os.makedirs(path)
        shutil.move(os.path.join(folder, filename), os.path.join(path, filename))
    
def create_sample_folder(_from, to, percentage=0.1, move=True):
    """Send a random share of every class folder in *_from* to *to*.

    For each immediate sub-directory of *_from*, a sub-directory with the
    same name is created under *to* (if missing) and
    ``int(len(files) * percentage)`` randomly chosen files are transferred.

    Args:
        _from: Source directory holding one sub-directory per class.
        to: Destination directory; created when it does not exist.
        percentage: Share of each class folder's files to sample.
        move: Move the sampled files when true; copy their contents with
            ``shutil.copyfile`` when false, leaving the sources in place.
    """
    if not os.path.exists(to):
        os.makedirs(to)

    # The flag is constant for the whole call, so pick the operation once.
    transfer = shutil.move if move else shutil.copyfile

    class_dirs = next(os.walk(_from))[1]
    for class_dir in class_dirs:
        src_dir = os.path.join(_from, class_dir)
        dst_dir = os.path.join(to, class_dir)
        if not os.path.exists(dst_dir):
            os.makedirs(dst_dir)

        candidates = next(os.walk(src_dir))[2]
        how_many = int(len(candidates) * percentage)
        for picked in random.sample(candidates, how_many):
            transfer(os.path.join(src_dir, picked), os.path.join(dst_dir, picked))

 

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值