今天学习了油管up主KevinRush的教程Cats vs. Dogs(只找到油管地址),还找到了Maples丶丶的博客,感谢这位大神,写的很详细,这里参考大神的博客记录一下学习笔记。
数据集可以从Kaggle官网上下载:https://www.kaggle.com/c/dogs-vs-cats
代码见:http://blog.csdn.net/c20081052/article/details/76376688
工程分为input_data.py,model.py和training.py三部分。其中重点介绍第一部分。
1. input_data.py
分为三部分:获取数据集划分标签get_files()
,分批次get_batch()
和测试Test。
1.1 get_files()
Kaggle提供的数据集包含了猫和狗图片各12500幅,都是以cat.<数字>.jpg或dog.<数字>.jpg命名,因此可以根据文件名分类打标签。get_files()
就是用于读取数据集,根据文件名,对数据集打标签,以列表形式返回图片和标签。
主要流程:
1.读取数据集,根据文件名,分成cat和dog两类图片和标签。这里cat和dog各有12500幅图片。
2.使用np.hstack()
将cat和dog的图片和标签整合为列表image_list
和label_list
,image_list
和label_list
的大小均为25000。
3.将image_list
和label_list
合并,存放在temp
中,此时temp
的大小为2x25000。对temp
进行转置,temp
的大小变为25000x2。
4.使用np.random.shuffle()
打乱图片和标签。
5.从temp
中取出乱序后的image_list
和label_list
列向量并返回。
代码如下:
#读取数据和标签
def get_files(file_dir):
cats = []
label_cats = []
dogs = []
label_dogs = []
for file in os.listdir(file_dir): #返回文件名
name = file.split(sep='.') #文件名按.分割
if name[0]=='cat': #如果是cat,标签为0,dog为1
cats.append(file_dir + file)
label_cats.append(0)
else:
dogs.append(file_dir + file)
label_dogs.append(1)
print('There are %d cats\nThere are %d dogs' %(len(cats), len(dogs))) #打印猫和狗的数量
image_list = np.hstack((cats, dogs))
label_list = np.hstack((label_cats, label_dogs))
temp = np.array([image_list, label_list])
temp = temp.transpose()
np.random.shuffle(temp) #打乱图片
image_list = list(temp[:, 0])
label_list = list(temp[:, 1])
label_list = [int(i) for i in label_list] #将label_list中的数据类型转为int型
return image_list, label_list
1.2 get_batch()
由于数据集较大,需要分批次通过网络。get_batch()
就是用于将图片划分批次。
主要流程:
1.image
和label
为list类型,转换为TensorFlow可以识别的tensor格式。
2.使用tf.train.slice_input_producer()
将image
和label
合并生成一个队列,然后从队列中分别取出image
和label
。其中image
需要使用tf.image.decode_jpeg()
进行解码,由于图片大小不统一,使用tf.image.resize_image_with_crop_or_pad()
进行裁剪/扩充,最后使用tf.image.per_image_standardization()
进行标准化,此时的image
的shape为[208 208 3]。
3.因为之前已经进行了乱序,使用tf.train.batch()
生成批次,最后得到的image_batch
和label_batch
的shape分别为[1 208 208 3]和[1]。
4.这里原作者代码中对label_batch
又进行reshape,是多余的,删除后无影响。最终返回image_batch
和label_batch
。
Maples丶丶的博客提到,原代码的get_batch()
中使用tf.image.resize_image_with_crop_or_pad
效果欠佳,这种方法是从图像中心向四周裁剪,当图片超过规定尺寸时,只保留规定尺寸的中心区域,就会造成裁剪后的图片中只有狗或猫的一部分躯干,如下图,影响最终的训练结果。