最近在学pytorch,今天晚上用pytorch的数据加载部分,一开始一直在纠结怎么划分数据集,后来还是手动分了,开始是用torch.utils.data.random_split
但是后来一直报错,我也不知道哪里有错,解决不了,后来暴力解决了
1.重写dataset类,这是必须要写的
- 主要继承Dataset类,重写
__getitem__
,and__len__
的方法 - 我的问题:针对一个文件夹有n张图片,然后一个csv文件中有每个图片对应的label,具体样式如下
步骤1:将image和label对应加载到一个数据集中
class SkinDataset(Dataset):
def __init__(self,csv_file,root_dir,transform=None):
self.csv=pd.read_csv(csv_file)
self.root_dir=root_dir
self.transform=transform
def __len__(self):
return len(self.csv)
def __getitem__(self,idx):
image_path=os.path.join(self.root_dir+self.csv.ix[idx,0]+'.jpg')