新手一枚,记录一下学习的艰辛历程,如果有什么错误,欢迎大家多多指教。根据所学知识,神经网络是允许图片直接输入进行训练的,可是当输入数据集太大时就不大方便啦,一般是生成tfrecords的形式进行训练,实践证明,这样做真的很方便!所以猫狗的识别第一步就是将数据图片生成TF文档。本人使用的是数据集来源于Kaggle,数据集有12500只猫和12500只狗。数据集可以去网站下载~
一 TFrecords的生成
首先下载好的数据集中的train文件夹中的猫狗分开,形成下图中的文件夹
接下来不多说附上程序代码
import os
import tensorflow as tf
from PIL import Image
cwd='E:\\BaiduNetdiskDownload\kaggle\\train\\'#上述文件夹地址
classes={
'cat','dog'}
writer= tf.python_io.TFRecordWriter("cat_dog.tfrecords")
for index,name in enumerate(classes):
class_path=cwd+name+'/'
for img_name in os.listdir(class_path):
img_path=class_path+img_name
img=Image.open(img_path)
img= img.resize((64,64))
img_raw=img.tobytes()
#plt.imshow(img) # if you want to check you image,please delete '#'
#plt.show()
example = tf.train.Example(features=tf.train.Features(feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])),
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
}))
writer.write(example.SerializeToString())
writer.close()
二 TF文档的读取
新建input_data.py文件,返回图像和标签
import tensorflow as tf
def read_and_decode(tfrecords_file): # read iris_contact.tfrecords
filename_queue = tf.train.string_input_producer([tfrecords_file])# create a queue
#队列生成
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)#return file_name and file
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw' : tf.FixedLenFeature([], tf.string),
})#return image and label
img = tf.decode_raw(features['img_raw'], tf.uint8)
img = tf.reshape(img, [64, 64, 3]) #reshape image to 512*80*3
img = tf.cast(img, tf.f