代码部分采用jupyter notebook编写
首先建立了cat和dog两个文件夹,并从网上为每一类各找了一张图片放进去
代码主要如下
主要思路是建立一个csv文件:每一类图片先通过简单的数据增强(本文选了旋转)将猫狗图片分别增加到十张,然后将猫的图片赋予标签0,狗的图片赋予标签1,把图片路径和标签写入csv文件中;接着制作自定义的Dataset,主要是写好__init__、__getitem__以及__len__三个方法,其实就是将image和label逐条读取;后续再利用DataLoader实现迭代的数据循环,最后测试了一下DataLoader里的图片和标签是否正确。
# Raw string literals keep the Windows backslashes literal; the non-raw
# forms contained invalid escape sequences, which newer Python versions
# warn about. The resulting path values are unchanged.

# Path of the label CSV that showpath() writes and Dataset reads.
csv_dir = os.path.join(r'E:\教程\CSDNlabel.csv', 'label' + '.csv')
# Dataset root; each entry inside it is treated as one class folder.
dir1 = r'E:\教程\CSDN'
# Snapshot of the entries under the dataset root, taken once at load time.
dir2 = os.listdir(dir1)
def showpath():
    """Augment the class folders under ``dir1`` and write the label CSV.

    For every class folder (``cat`` and ``dog``), each source image is
    rotated in 20-degree steps and the ten results are saved next to it as
    ``<class><k>.jpg`` for ``k`` in 0..9. Afterwards every file found in the
    class folders is written to ``csv_dir`` as a ``path,label`` row, with
    label ``0`` for ``cat`` and ``1`` for ``dog``.

    The label is looked up from the folder name rather than from the
    position of the folder in the directory listing, so the result does not
    depend on listing order or on extra entries under ``dir1``. Entries that
    are not a known class are skipped.

    NOTE: all source images of one class write to the same ten file names,
    so with more than one source image per folder only the rotations of the
    last one processed remain on disk.
    """
    class_labels = {'cat': '0', 'dog': '1'}
    copies = 10
    angle_step = 20

    # --- augmentation -----------------------------------------------------
    for class_name in dir2:
        if class_name not in class_labels:
            continue
        class_dir = os.path.join(dir1, class_name)
        generated = {class_name + str(k) + '.jpg' for k in range(copies)}
        for file_name in os.listdir(class_dir):
            # Skip files produced by an earlier run so they are not rotated
            # again and overwritten with rotations of rotations.
            if file_name in generated:
                continue
            # Close the file handle promptly; convert to RGB because JPEG
            # cannot store palette or alpha modes.
            with Image.open(os.path.join(class_dir, file_name)) as img:
                source = img.convert('RGB')
            for k in range(copies):
                target = os.path.join(class_dir, class_name + str(k) + '.jpg')
                source.rotate(angle_step * k).save(target)

    # --- label file -------------------------------------------------------
    with open(csv_dir, 'w', newline='') as f:
        writer = csv.writer(f)
        for class_name in dir2:
            label = class_labels.get(class_name)
            if label is None:
                continue
            class_dir = os.path.join(dir1, class_name)
            for file_name in os.listdir(class_dir):
                writer.writerow([os.path.join(class_dir, file_name), label])
def default_loader(path):
    """Load the image at ``path`` and return it as an RGB PIL image.

    The image is converted to RGB so that grayscale, palette or RGBA files
    yield three channels, which the three-value mean/std normalisation used
    by the training transform requires. The file is opened through a context
    manager so its handle is released once the pixel data has been read.

    Args:
        path: Filesystem path of the image file.

    Returns:
        The decoded image in ``RGB`` mode.
    """
    with open(path, 'rb') as f:
        img = Image.open(f)
        # convert() forces the pixel data to be read before the file closes.
        return img.convert('RGB')
class Dataset:
    """Map-style dataset backed by the label CSV.

    Each CSV row holds an image path followed by an integer class label.
    Items are returned as ``(image, label)`` pairs.
    """

    def __init__(self, loader=default_loader, transform=None, csv_path=None):
        """Read every ``(path, label)`` row of the CSV into memory.

        Args:
            loader: Callable that takes an image path and returns the image.
            transform: Optional callable applied to each loaded image.
            csv_path: CSV file to read. Defaults to the module-level
                ``csv_dir`` when omitted.

        Raises:
            IndexError: If a non-empty row has fewer than two fields.
            ValueError: If a label field is not an integer.
        """
        if csv_path is None:
            csv_path = csv_dir
        imgs = []
        # Parse with the csv module (not str.split) so that paths the writer
        # quoted, e.g. ones containing commas, are read back correctly.
        with open(csv_path, 'r', newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines instead of failing on row[0].
                    continue
                imgs.append((row[0], int(row[1])))
        self.imgs = imgs
        self.loader = loader
        self.transform = transform

    def __len__(self):
        """Return the number of samples listed in the CSV."""
        return len(self.imgs)

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair at ``index``.

        The transform is applied only when one was supplied, so the default
        ``transform=None`` returns the image exactly as the loader produced it.
        """
        path, label = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, label
import torchvision
import matplotlib.pyplot as plt

# Generate the augmented images and the label CSV first: Dataset reads the
# CSV in its constructor, so it must exist before the dataset is built.
showpath()

train_transform = transforms.Compose([
    transforms.Resize(280),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])
trainset = Dataset(transform=train_transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True)

# Sanity check: show one batch together with its labels.
image, labels = next(iter(trainloader))
# Undo Normalize(mean=0.5, std=0.5) so pixel values are back in [0, 1];
# otherwise imshow clips the negative half of the range.
image = image * 0.5 + 0.5
image = torchvision.utils.make_grid(image)
# make_grid returns (C, H, W); matplotlib expects (H, W, C).
image = image.numpy().transpose(1, 2, 0)
print(labels.tolist())
plt.imshow(image)
plt.show()