[Pytorch]PyTorch Dataloader自定义数据读取

[Pytorch]PyTorch Dataloader自定义数据读取

整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变。

所有图片都在一个文件夹1

之前刚开始用的时候,写Dataloader遇到不少坑。网上有一些教程 分为all images in one folder 和 each class one folder。后面的那种写的人比较多,我写一下前面的这种,程式化的东西,每次不同的任务改几个参数就好。

等训练的时候写一篇文章把2333


一.已有的东西

举例子:用kaggle上的一个dog breed的数据集为例。数据文件夹里面有三个子目录

test: 几千张图片,没有标签,测试集

train: 10222张狗的图片,全是jpg,大小不一,有长有宽,基本都在400×300以上

labels.csv : excel表格, 图片名称+品种名称

<img src="https://pic4.zhimg.com/v2-6128d817c09f05fe3bdfe05e1f84a92f_b.jpg" data-caption="" data-size="normal" data-rawwidth="496" data-rawheight="85" class="origin_image zh-lightbox-thumb" width="496" data-original="https://pic4.zhimg.com/v2-6128d817c09f05fe3bdfe05e1f84a92f_r.jpg"> v2-6128d817c09f05fe3bdfe05e1f84a92f_hd.jpg

我喜欢先用pandas把表格信息读出来看一看

import pandas as pd
import numpy as np
df = pd.read_csv('./dog_breed/labels.csv')
print(df.info())
print(df.head())
<img src="https://pic1.zhimg.com/v2-9d680235e5ff00c3f869a3eab4630ca4_b.jpg" data-caption="" data-size="normal" data-rawwidth="731" data-rawheight="265" class="origin_image zh-lightbox-thumb" width="731" data-original="https://pic1.zhimg.com/v2-9d680235e5ff00c3f869a3eab4630ca4_r.jpg"> v2-9d680235e5ff00c3f869a3eab4630ca4_hd.jpg

看到,一共有10222个数据,id对应的是图片的名字,但是没有后缀 .jpg。 breed对应的是犬种。


二.预处理

我们要做的事情是:

1)得到一个长 list1 : 里面是每张图片的路径

2)另外一个长list2: 里面是每张图片对应的标签(整数),顺序要和list1对应。

3)把这两个list切分出来一部分作为验证集


1)看看一共多少个breed,把每种breed名称和一个数字编号对应起来:

from pandas import Series,DataFrame

breed = df['breed']
breed_np = Series.as_matrix(breed)
print(type(breed_np) )
print(breed_np.shape)   #(10222,)

#看一下一共多少不同种类
breed_set = set(breed_np)
print(len(breed_set))   #120

#构建一个编号与名称对应的字典,以后输出的数字要变成名字的时候用:
breed_120_list = list(breed_set)
dic = {}
for i in range(120):
    dic[  breed_120_list[i]   ] = i

2)处理id那一列,分割成两段:

file =  Series.as_matrix(df["id"])
print(file.shape)

import os
file = [i+".jpg" for i in file]
file = [os.path.join("./dog_breed/train",i) for i in file ]
file_train = file[:8000]
file_test = file[8000:]
print(file_train)

np.save( "file_train.npy" ,file_train )
np.save( "file_test.npy" ,file_test )

里面就是图片的路径了

<img src="https://pic3.zhimg.com/v2-b740e480301df1fded91c92090065736_b.jpg" data-caption="" data-size="normal" data-rawwidth="1076" data-rawheight="113" class="origin_image zh-lightbox-thumb" width="1076" data-original="https://pic3.zhimg.com/v2-b740e480301df1fded91c92090065736_r.jpg"> v2-b740e480301df1fded91c92090065736_hd.jpg

3)处理breed那一列,分成两段:

breed = Series.as_matrix(df["breed"])
print(breed.shape)
number = []
for i in range(10222):
    number.append(  dic[ breed[i] ]  )
number = np.array(number) 
number_train = number[:8000]
number_test = number[8000:]
np.save( "number_train.npy" ,number_train )
np.save( "number_test.npy" ,number_test )

三.Dataloader

我们已经有了图片路径的list,target编号的list。填到Dataset类里面就行了。

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)
preprocess = transforms.Compose([
    #transforms.Scale(256),
    #transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

def default_loader(path):
    img_pil =  Image.open(path)
    img_pil = img_pil.resize((224,224))
    img_tensor = preprocess(img_pil)
    return img_tensor

#当然出来的时候已经全都变成了tensor
class trainset(Dataset):
    def __init__(self, loader=default_loader):
        #定
  • 14
    点赞
  • 97
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值