刚刚开始接触TensorFlow,在网上找到一个老师的视频教学,关于利用简单的网络结构实现二分类,也就是猫狗大战的例子。于是对着视屏,自己做了如下的记录。如果有什么地方错误,欢迎大神指点!!!
1.准备数据集。https://www.kaggle.com/c/dogs-vs-cats 在Kaggle的官网上可以下载,如果想直接要数据的童鞋,可以私信我,我发给你,也是可以的。
2.程序结构如图:
3.下面就是关键代码:
input_data.py:
import tensorflow as tf
import numpy as np
import os
img_width=208
img_height=208
#train_dir='G:/PycharmProjects/untitled2/data/train/'
def get_files(file_dir):
cats=[]
label_cats=[]
dogs=[]
label_dogs=[]
for file in os.listdir(file_dir):
name=file.split(sep='.')#文件是cat.1.jpg形式
if name[0]=='cat':
cats.append(file_dir+file)
label_cats.append(0)
else:
dogs.append(file_dir+file)
label_dogs.append(1)
image_list=np.hstack((cats,dogs))
label_list=np.hstack((label_cats,label_dogs))
temp=np.array([image_list,label_list])
temp=temp.transpose()
np.random.shuffle(temp)#打乱顺序
image_list=list(temp[:,0])
label_list=list(temp[:,1])
label_list=[int(i) for i in label_list]
return image_list,label_list
# 生成相同大小的批次
def get_batch(image,label,image_w,image_h,batch_size,capacity):
image=tf.cast(image,tf.string)
label=tf.cast(label,tf.int32)
input_queue=tf.train.slice_input_producer([image,label])
label=input_queue[1]
image_contents=tf.read_file(input_queue[0])
image=tf.image.decode_jpeg(image_contents,channels=3) #解码jpg图片
image=tf.image.resize_image_with_crop_or_pad(image,image_w,image_h) #图片过大 从中间裁剪
image=tf.imag