一、前言
猫狗大战挑战由Kaggle于2013年举办的,目前比赛已经结束,不过仍然可以把AI研习社猫狗大战赛平台作为练习赛每天提交测试结果,该平台数据集包含猫狗图片共24000张,没有任何标注数据,选手需要训练模型正确识别猫狗图片,1= dog,0 = cat。这里使用在 ImageNet 上预训练的 VGG 网络模型进行测试,因为原网络的分类结果是1000类,所以要进行迁移学习,对原网络进行 fine-tune (即固定前面若干层,作为特征提取器,只重新训练最后两层),并把测试结果提交到该平台。那么,现在就让我们开始吧。
二、加载数据集
前期如何把解压后的竞赛数据集放到colab上着实耗费了我大量的时间,我认为非常有必要把这个单独作为一章讲一下。如果你本地有很强的GPU,不需要在colab上跑代码,这章节可以忽略,由于我的电脑跑不动这么多数据,GPU也不行,所以只能在colab上运行。在这个过程中许多问题本是可以避免的,由于对一些操作和指令不熟练,导致许多时间白白流失,即打消了初学者的自信心,也拖慢了实验的进度,究其原因,主要有以下几点:
- 在google drive上传和解压数据集时间特别慢,需要数十个小时
- colab运行时间有时限,长时间不操作(大概20分钟左右)会导致当前训练的数据被回收
- 猫狗大战数据集是没有标签的,需要自己定义Dataset类加载数据
现在就来一个个解决上面的几个痛点吧!
(1)colab上传和解压大数据集
我们的目的是要在colab上读取竞赛数据集的图片,达到目的的方式有三个:
- 方式一:把数据集压缩包上传到google drive,在drive上解压
- 方式二:数据集解压后再上传到google drive
- 方式三:把数据集压缩包上传到google drive,在colab连接的虚拟机上解压
上面几种方式哪个好呢?我先不直接说结果,来实验下吧!
首先,采用方式一,把数据集压缩包上传到google drive,在drive上解压,操作很简单,在google drive上右键上传竞赛数据集cat_dog.rar,文件大小521MB,上传时间二十多分钟,上传完毕后,再drive上解压,现在痛点来了,时间竟然要十几个小时,具体操作如下:
-
打开colab,挂载google drive,方法可以参考我的博客Google Colab挂载drive上的数据文件。
-
解压drive上的cat_dog.rar文件,命令为
! apt-get install rar !unrar x "/content/drive/Colab/人工智能课/cat_dog.rar" "/content/drive/Colab/人工智能课/"
解压过程如下:
我大致算了一下,每张图片解压时间5秒钟左右,24000张图片要大约33小时啊!!!所以,这种方式直接pass掉。
再来看,方式二,把数据集解压后再上传到google drive,解压后的数据集文件夹大小虽然只有五百多兆,但上传速度特别慢,大概要5至7个小时,并且一旦中间断网或是网络不稳定,极有可能导致数据损坏。我就是花费了大半天时间把所有解压后的文件上传完了,由于中间网络不稳定,导致数据读取不正确,最终这种方式也放弃了,哎,说多了都是泪!
最后,就只有方式三了,把数据集压缩包上传到google drive,在colab连接的虚拟机上解压文件,方法是:
-
将google drive上数据集文件cat_dog.rar拷贝到colab连接的虚拟机上
!cp -i /content/drive/Colab/人工智能课/cat_dog.rar /content/
-
在虚拟机上解压压缩文件:
! apt-get install rar ! unrar x cat_dog.rar
运行过程如下:
这种方式速度非常快,如果操作正确,解压时间仅有一分钟左右,非常值得推荐!
(2)阻止Colab自动掉线
在colab上训练代码,页面隔一段时间无操作之后就会自动掉线,之前训练的数据都会丢失。现在你体会到我之前连续几个小时在google drive解压数据集文件的艰辛路程了吧。不过好在最后终于找到了一种可以让其自动保持不离线的方法,用一个js程序自动点击连接按钮。代码如下:
function ClickConnect(){
console.log("Working");
document
.querySelector("#top-toolbar > colab-connect-button")
.shadowRoot
.querySelector("#connect")
.click()
}
setInterval(ClickConnect,60000)
使用方式是:按快捷键ctrl+shift+i
,并选择Console
,然后复制粘贴上面的代码,并点击回车,该程序便可以运行了,如下所示:
(3)猫狗大战数据集是没有标签的,需要自己定义Dataset类才能加载数据
猫狗大战数据集是没有标签的,但是从其训练集和验证集的图片名字可以获取标签,这就需要我们自己定义Dataset类了,由于这个部分篇幅较多,我们放在下一章讲吧。
三、数据预处理
传统的mnist数据集是集成到torchvision.datasets
,我们使用datasets.MNIST
就可以方便加载数据,不用做过多的其它处理,而猫狗大战竞赛数据集是如下图方式,并没有用标签对文件夹分类存放,所以我们需要通过图片名称获取标签,并自定义Dataset类加载图片。
我定义的Dataset类如下所示:
from torch.utils.data import Dataset,DataLoader
# 创建自己的类:MyDataset,继承 Dataset 类
class MyDataset(Dataset):
def __init__(self, txt, data_path=None, transform=None, target_transform=None, loader=default_loader):
super(MyDataset, self).__init__()
file_path = data_path + txt
file = open(file_path, 'r', encoding='utf8')
imgs = []
for line in file:
line = line.split()
imgs.append((line[0],line[1].rstrip('\n')))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
self.data_path = data_path
# 可以通过索引进行条用,如data[1]
def __getitem__(self, index):
# 按照索引读取每个元素的具体内容
imgName, label = self.imgs[index]
# imgPath = self.data_path + imgName
imgPath = imgName
# 调用那张图片读哪张,最大限度发挥GPU显存
img = self.loader(imgPath)
if self.transform is not None:
img = self.transform(img)
label = torch.from_numpy(np.array(int(label)))
return img, label
def __len__(self):
# 数据集的图片数量