custom dataset implementation 自定义数据集,与dataloader的使用

我们经常会遇到使用非大众数据集(例如imagenet,cifar10)的情况,因此需要自己实现一个dataset类。

nips17数据

nips17:
数据标签,我们仅考虑true label和target class
在这里插入图片描述
数据形式
在这里插入图片描述

提取标签信息

    csv_note_dir="../data_nips2017/dev_dataset.csv"

    # load notes
    import pandas as pd
    df = pd.read_csv(csv_note_dir)
    image_id=np.array(df.iloc[:,0])
    true_label=np.array(df.iloc[:,6])-1
    target_label=np.array(df.iloc[:,7])-1

定义loader函数

注意这里的load函数返回最好tensor类型,因为运行到时候会有个torch.as_tensor()函数,如果数据类型不对会报错(我尝试了int和numpy都会报,也有可能是版本问题,没有过多追究)

def default_loader(path): #tensor
        img_pil =  Image.open(path)
        img_tensor = transform(img_pil)
        return img_tensor

定义数据集类

自定义类继承于torch.utils.data.Dataset, 我们仅需要改构造函数,getitem和len函数即可

from torch.utils.data import Dataset

class custom_dataset(Dataset):
    def __init__(self, root, img_id,true_label, target_label, loader):
        #image即图片路径,通过对每个图片定义路径来调用它
        self.images = root+"/"+img_id+".png"
        self.true_label = true_label
        self.target = target_label
        self.loader = loader

    def __getitem__(self, index): #定义 getitem
        fn = self.images[index]
        img = self.loader(fn)
        target = self.target[index]
        label = self.true_label[index]
        return img,label,target

    def __len__(self): #定义len
        return len(self.images)

代码中调用

normal_data=custom_dataset(data_dir,image_id,true_label,target_label,default_loader)
normal_loader = torch.utils.data.DataLoader(normal_data, batch_size=batch_size, shuffle=False)
normal_iter = iter(normal_loader)
for i in range(datasize/batch_size):
   images, labels,targets = normal_iter.next()  

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值