使用VGG迁移学习开启《猫狗大战挑战赛》

一、前言

猫狗大战挑战由Kaggle于2013年举办的,目前比赛已经结束,不过仍然可以把AI研习社猫狗大战赛平台作为练习赛每天提交测试结果,该平台数据集包含猫狗图片共24000张,没有任何标注数据,选手需要训练模型正确识别猫狗图片,1= dog,0 = cat。这里使用在 ImageNet 上预训练的 VGG 网络模型进行测试,因为原网络的分类结果是1000类,所以要进行迁移学习,对原网络进行 fine-tune (即固定前面若干层,作为特征提取器,只重新训练最后两层),并把测试结果提交到该平台。那么,现在就让我们开始吧。

二、加载数据集

前期如何把解压后的竞赛数据集放到colab上着实耗费了我大量的时间,我认为非常有必要把这个单独作为一章讲一下。如果你本地有很强的GPU,不需要在colab上跑代码,这章节可以忽略,由于我的电脑跑不动这么多数据,GPU也不行,所以只能在colab上运行。在这个过程中许多问题本是可以避免的,由于对一些操作和指令不熟练,导致许多时间白白流失,即打消了初学者的自信心,也拖慢了实验的进度,究其原因,主要有以下几点:

  1. 在google drive上传和解压数据集时间特别慢,需要数十个小时
  2. colab运行时间有时限,长时间不操作(大概20分钟左右)会导致当前训练的数据被回收
  3. 猫狗大战数据集是没有标签的,需要自己定义Dataset类加载数据

现在就来一个个解决上面的几个痛点吧!

(1)colab上传和解压大数据集

我们的目的是要在colab上读取竞赛数据集的图片,达到目的的方式有三个:

  • 方式一:把数据集压缩包上传到google drive,在drive上解压
  • 方式二:数据集解压后再上传到google drive
  • 方式三:把数据集压缩包上传到google drive,在colab连接的虚拟机上解压

上面几种方式哪个好呢?我先不直接说结果,来实验下吧!

首先,采用方式一,把数据集压缩包上传到google drive,在drive上解压,操作很简单,在google drive上右键上传竞赛数据集cat_dog.rar,文件大小521MB,上传时间二十多分钟,上传完毕后,再drive上解压,现在痛点来了,时间竟然要十几个小时,具体操作如下:

  • 打开colab,挂载google drive,方法可以参考我的博客Google Colab挂载drive上的数据文件

  • 解压drive上的cat_dog.rar文件,命令为

    ! apt-get install rar
    !unrar x "/content/drive/Colab/人工智能课/cat_dog.rar" "/content/drive/Colab/人工智能课/"
    

    解压过程如下:

我大致算了一下,每张图片解压时间5秒钟左右,24000张图片要大约33小时啊!!!所以,这种方式直接pass掉。

再来看,方式二,把数据集解压后再上传到google drive,解压后的数据集文件夹大小虽然只有五百多兆,但上传速度特别慢,大概要5至7个小时,并且一旦中间断网或是网络不稳定,极有可能导致数据损坏。我就是花费了大半天时间把所有解压后的文件上传完了,由于中间网络不稳定,导致数据读取不正确,最终这种方式也放弃了,哎,说多了都是泪!

最后,就只有方式三了,把数据集压缩包上传到google drive,在colab连接的虚拟机上解压文件,方法是:

  • 将google drive上数据集文件cat_dog.rar拷贝到colab连接的虚拟机上

    !cp -i /content/drive/Colab/人工智能课/cat_dog.rar /content/
    
  • 在虚拟机上解压压缩文件:

    ! apt-get install rar
    ! unrar x cat_dog.rar
    

    运行过程如下:

这种方式速度非常快,如果操作正确,解压时间仅有一分钟左右,非常值得推荐!

(2)阻止Colab自动掉线

在colab上训练代码,页面隔一段时间无操作之后就会自动掉线,之前训练的数据都会丢失。现在你体会到我之前连续几个小时在google drive解压数据集文件的艰辛路程了吧。不过好在最后终于找到了一种可以让其自动保持不离线的方法,用一个js程序自动点击连接按钮。代码如下:

function ClickConnect(){
   
  console.log("Working"); 
  document
    .querySelector("#top-toolbar > colab-connect-button")
    .shadowRoot
    .querySelector("#connect")
    .click()
}
 
setInterval(ClickConnect,60000)

使用方式是:按快捷键ctrl+shift+i,并选择Console,然后复制粘贴上面的代码,并点击回车,该程序便可以运行了,如下所示:

(3)猫狗大战数据集是没有标签的,需要自己定义Dataset类才能加载数据

猫狗大战数据集是没有标签的,但是从其训练集和验证集的图片名字可以获取标签,这就需要我们自己定义Dataset类了,由于这个部分篇幅较多,我们放在下一章讲吧。

三、数据预处理

传统的mnist数据集是集成到torchvision.datasets,我们使用datasets.MNIST就可以方便加载数据,不用做过多的其它处理,而猫狗大战竞赛数据集是如下图方式,并没有用标签对文件夹分类存放,所以我们需要通过图片名称获取标签,并自定义Dataset类加载图片。

我定义的Dataset类如下所示:

from torch.utils.data import Dataset,DataLoader
# 创建自己的类:MyDataset,继承 Dataset 类
class MyDataset(Dataset):
    def __init__(self, txt, data_path=None, transform=None, target_transform=None, loader=default_loader):
        super(MyDataset, self).__init__()
        file_path = data_path + txt
        file = open(file_path, 'r', encoding='utf8')
        imgs = []
        for line in file:
            line = line.split()
            imgs.append((line[0],line[1].rstrip('\n')))

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader
        self.data_path = data_path

    # 可以通过索引进行条用,如data[1]
    def __getitem__(self, index):
        # 按照索引读取每个元素的具体内容
        imgName, label = self.imgs[index]
        # imgPath = self.data_path + imgName
        imgPath = imgName
        # 调用那张图片读哪张,最大限度发挥GPU显存
        img = self.loader(imgPath)
        if self.transform is not None:
            img = self.transform(img)
            label = torch.from_numpy(np.array(int(label)))
        return img, label

    def __len__(self):
        # 数据集的图片数量
        
  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

雷恩Layne

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值