数据集下载
数据集下载:https://pan.baidu.com/s/1SlNAPf3NbgPyf93XluM7Fg 密码: hpn4
一共包含12500张狗的照片,12500张猫的照片
数据处理
原始数据train文件夹里包含所有的图片,首先对其进行处理,生成一个图片名称与标签相对应的txt文件,以便进行索引。将猫的标签对应为0,狗的标签对应为1
import os
def text_save(filename, data_dir, data_class):
    """Append image-path / label pairs to a text file, one "path label" line each.

    Args:
        filename: path of the output txt file (opened in append mode, so
            repeated calls accumulate lines).
        data_dir: list of image file paths.
        data_class: list of integer labels parallel to data_dir.
    """
    # 'with' guarantees the handle is closed even if a write raises;
    # the original opened the file and only closed it on the happy path.
    with open(filename, 'a') as file:
        # zip pairs paths with labels directly instead of indexing by range().
        for path, label in zip(data_dir, data_class):
            file.write('{} {}\n'.format(path, label))
    print('文件保存成功')
def get_files(file_dir):
    """Scan file_dir and build parallel lists of image paths and labels.

    File names are expected to look like 'cat.123.jpg' / 'dog.123.jpg';
    anything whose name does not start with 'cat' is labeled as a dog.

    Args:
        file_dir: directory containing the raw training images.

    Returns:
        (image_list, label_list): all cat paths first, then all dog paths,
        with matching integer labels (0 = cat, 1 = dog).
    """
    cats, dogs = [], []
    cat_labels, dog_labels = [], []
    for fname in os.listdir(file_dir):
        # os.path.join works whether or not file_dir has a trailing slash;
        # the original 'file_dir + file' concatenation silently produced
        # broken paths when the slash was missing.
        path = os.path.join(file_dir, fname)
        if fname.split(sep='.')[0] == 'cat':
            cats.append(path)
            cat_labels.append(0)  # 0 = cat
        else:
            dogs.append(path)
            dog_labels.append(1)  # 1 = dog
    print('There are %d cats and %d dogs' % (len(cats), len(dogs)))
    # Cats first, then dogs — same ordering the original produced via extend().
    image_list = cats + dogs
    label_list = cat_labels + dog_labels
    return image_list, label_list
def data_process():
    """Build train.txt: one '<image path> <label>' line per image in train/."""
    images, labels = get_files('train/')
    text_save('train.txt', images, labels)
加载训练数据
#重写dataset类,用于加载dataloader
class train_Dataset(Dataset):
    """Dataset backed by a txt index file of '<image path> <int label>' lines.

    The index is parsed eagerly into self.imgs; images themselves are loaded
    lazily in __getitem__, so the DataLoader only pays for what it fetches.
    """

    def __init__(self, txt_path, transform=None, target_transform=None):
        """Parse the index file into self.imgs = [(path, label), ...].

        Args:
            txt_path: text file with one 'path label' pair per line.
            transform: optional callable applied to the PIL image.
            target_transform: optional callable; stored for interface
                compatibility but not applied in __getitem__.
        """
        imgs = []
        # 'with' closes the file even if parsing raises; the original
        # opened the handle and never closed it.
        with open(txt_path, 'r') as fh:
            for line in fh:
                words = line.rstrip().split()
                if not words:
                    continue  # skip blank lines instead of raising IndexError
                imgs.append((words[0], int(words[1])))
        # DataLoader hands us an index; __getitem__ resolves it via this list.
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Load and return (image, label) for the given index."""
        fn, label = self.imgs[index]
        # Pixel values are 0~255; transforms.ToTensor() rescales them to 0~1.
        img = Image.open(fn).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)  # e.g. conversion to tensor
        return img, label

    def __len__(self):
        return len(self.imgs)
训练函数
def save_models(net,epoch):#模型保存函数,自己更改位置
torch.save(net.state_dict(),'/home/cat/mymodel_epoch_1{}.pth'.format(epoc