Pytorch实战:Kaggle下基于Resnet猫狗识别具体实现代码

该博客介绍了使用PyTorch在Kaggle上进行猫狗识别的实战过程,包括数据集下载、数据处理、训练数据加载、训练函数的编写以及训练后的测试。作者分享了完整的代码,并提供了在Linux环境下运行的命令。文章还揭示了一个关键的LogLoss优化技巧:通过调整预测概率(如狗为0.995,猫为0.005),可以显著提高评估分数。
摘要由CSDN通过智能技术生成

Pytorch实战:Kaggle下基于Resnet猫狗识别具体实现代码

数据集下载

数据集下载:https://pan.baidu.com/s/1SlNAPf3NbgPyf93XluM7Fg 密码: hpn4
一共包含12500张狗的照片,12500张猫的照片

数据处理

原始数据train文件家里包含所有的图片,首先对其进行处理,生成一个图片名称与标签相对应的txt文件,好进行索引。将猫的标签对应为0,狗的标签对应为1

import os
def text_save(filename,data_dir,data_class):
    file = open(filename,'a')
    for i in range(len(data_class)):
        s = str(data_dir[i]+' '+str(data_class[i])) +'\n'
        file.write(s)
    file.close()
    print('文件保存成功')

def get_files(file_dir):
    #file_dir 文件路径
    cat = []
    dog = []
    label_dog = []
    label_cat = []
    for file in os.listdir(file_dir):
        name = file.split(sep = '.')
        if name[0]=='cat':
            cat.append(file_dir + file)
            label_cat.append(0)#0对应猫
        else:
            dog.append(file_dir + file)
            label_dog.append(1)
    print('There are %d cats and %d dogs' %(len(cat), (len(dog))))

    cat.extend(dog)
    label_cat.extend(label_dog)
    image_list = cat
    label_list = label_cat
    print(type(image_list))
    return image_list,label_list

def data_process():#生成train.txt,包含图片名称一级标签
    image_list, label_list = get_files('train/')
    text_save('train.txt', image_list, label_list)

加载训练数据

#重写dataset类,用于加载dataloader
class train_Dataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, 'r')
        imgs = []
        for line in fh:
            line = line.rstrip()
            words = line.split()
            imgs.append((words[0], int(words[1])))

        self.imgs = imgs        # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        img = Image.open(fn).convert('RGB')     # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1
        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等
        return img, label

    def __len__(self):
        return len(self.imgs)

训练函数

def save_models(net,epoch):#模型保存函数,自己更改位置
    torch.save(net.state_dict(),'/home/cat/mymodel_epoch_1{}.pth'.format(epoc
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值