pytorch dataloader 和 dataset 数据加载的研究

一 pytorch 数据加载的研究

一、dataloader and dataset?

Dataset抽象类,所有自定义的Dataset都需要继承它,并且必须复写__getitem__()这个类方法。

DataLoader(): 迭代器, 我们在训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。

二、类的实例化

大多数文章,并没有仔细探究Dataset这个类,究竟是怎么一步步完成数据和标签的加载的

first of all ,它是个类
所以,从类的角度,继承,重写,实例化,这个面向对象的思路,先研究一下

1.继承Dataset

代码如下(示例):xxx代表可以自己定义的内容

class myDataset(Dataset):
    def __init__(self, xxx):
     
    def __getitem__(self,index):
            return xxx,xxx
            
    def __len__(self):
        return len(xxx)

可见,getitem 和 len 需要自己重写,并返回一些东西

2.重写父类函数

这里,采用了kaggle 的Dog Breed Identification项目的数据
是个分类任务,使用resnet vgg 就可以解决
数据集包含 3个文件
train(文件夹)
test(文件夹)
label.csv
在这里插入图片描述
可以到官网看 https://www.kaggle.com/competitions/dog-breed-identification/
代码如下(示例):

from torch.utils.data import Dataset
import pandas as pd
import cv2
class myDataset(Dataset):
    def __init__(self, dogdir):
        self.imgset =  dogdir["id"]
        self.labelset = dogdir["breed"]
        dog_breeds = sorted(list(set(self.labelset )))
        n_classes = len(dog_breeds)
        self.class_to_num = dict(zip(dog_breeds, range(n_classes)))
        
    def __getitem__(self,index):
            imgpath = "train/"+self.imgset[index] + ".jpg"
            img = cv2.imread(imgpath)
            labelname  = self.labelset[index]
            labelhot =  self.class_to_num.get(labelname)
            return img, labelhot
    
    def __len__(self):
        return len(self.imgset)

3.实例化

看看继承后的dataset

df = pd.read_csv('labels.csv')  #使用pandas读取csv  
myd = myDataset(df)
img,label = myd.__getitem__(4) #指定4这个item
lenth = myd.__len__()
#print(img)
print(label)
print(lenth)

49 #one hot 编码后的标签
10222 # 总体数量

总结

提示:这里对文章进行总结:

例如:以上就是今天要讲的内容

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值