我们经常会遇到使用非大众数据集(例如imagenet,cifar10)的情况,因此需要自己实现一个dataset类。
nips17数据
nips17:
数据标签,我们仅考虑true label和target class
数据形式
提取标签信息
csv_note_dir="../data_nips2017/dev_dataset.csv"
# load notes
import pandas as pd
df = pd.read_csv(csv_note_dir)
image_id=np.array(df.iloc[:,0])
true_label=np.array(df.iloc[:,6])-1
target_label=np.array(df.iloc[:,7])-1
定义loader函数
注意这里的load函数返回最好tensor类型,因为运行到时候会有个torch.as_tensor()函数,如果数据类型不对会报错(我尝试了int和numpy都会报,也有可能是版本问题,没有过多追究)
def default_loader(path): #tensor
img_pil = Image.open(path)
img_tensor = transform(img_pil)
return img_tensor
定义数据集类
自定义类继承于torch.utils.data.Dataset, 我们仅需要改构造函数,getitem和len函数即可
from torch.utils.data import Dataset
class custom_dataset(Dataset):
def __init__(self, root, img_id,true_label, target_label, loader):
#image即图片路径,通过对每个图片定义路径来调用它
self.images = root+"/"+img_id+".png"
self.true_label = true_label
self.target = target_label
self.loader = loader
def __getitem__(self, index): #定义 getitem
fn = self.images[index]
img = self.loader(fn)
target = self.target[index]
label = self.true_label[index]
return img,label,target
def __len__(self): #定义len
return len(self.images)
代码中调用
normal_data=custom_dataset(data_dir,image_id,true_label,target_label,default_loader)
normal_loader = torch.utils.data.DataLoader(normal_data, batch_size=batch_size, shuffle=False)
normal_iter = iter(normal_loader)
for i in range(datasize/batch_size):
images, labels,targets = normal_iter.next()