解决上一次九分类存在的问题
1.将所有图片名和图片的类别存在一个csv文件中,name对应图片的名字,label对应图片的标签。
def generate_csv(path,type,csv_path):
with open(csv_path,'w',newline='') as csvfile:
svwriter=csv.writer(csvfile,dialect="excel")
svwriter.writerow(['name', 'label'])
listdir=os.listdir(os.path.join(path,type))
for i in listdir:
classimage=os.path.join(path,type,i)#0,1,2,3,4
data=os.listdir(classimage)
for j in data:
j=j.split('.')[0]
svwriter.writerow([j,i])
2.读取csv文件中对应的图片。由于csv文件中保存的是图片名称和图片的label,我们需要读取到具体的图片,此时我们需要重写dataset方法,将图片名和label放在字典中,由图片地址获得图片名,再根据图片名去字典中找对应图片的label,返回图片名以及label:
class mydataset(Dataset):
def __init__(self,data_folder,class_dict,transform=None):
self.data_folder=data_folder#存放图片的文件夹train,test,val
self.class_dict=class_dict #由csv得到的名,类别字典
self.transform=transform
self.imageclass=[s for s in os.listdir(data_folder)]
self.imagelist=[]
for i in self.imageclass:
self.path=os.path.join(data_folder,i)
self.ima=os.listdir(self.path)
for er in self.ima:
er=i+'/'+er
self.imagelist.append(er)
self.labels=[class_dict[i.split('.')[0].split('/')[1]] for i in self.imagelist]
def __len__(self):
return len(self.imagelist)
def __getitem__(self, idx):
image_path=os.path.join(self.data_folder,self.imagelist[idx])
img=Image.open(image_path)
imag=self.transform(img)
img_name=self.imagelist[idx].split('.')[0].split('/')[1]
label=self.class_dict[img_name]
return img_name,imag,label
def init()方法用来初始化,我们必须实现def len()和def__getitem__()方法,getitem()方法每次返回一个图片,len()是整个的长度,需要调用getitem()多少次
由于class mydataset(Dataset)每次只能返回一个图片,因此我们需要使用dataloader来一次加载多张图片。
def data_prepare(path,type,batch_size,transform):
csv_data=pd.read_csv(path+'/'+type+'.csv')
class_dict=