写在开始
自己最开始接触python的时候,第一个学会使用的库就是tensorflow,在经历了everyone 都会经历的mnist数据集训练后,就开始想自己做一个图片分类的深度学习,期间也是一波三折,看了很多csdn上的博客,摸索出了自己的数据集制作习惯,用以简单的分类。
数据集下载
数据集的话,一是通过自己下载收集,也可以去很多赛题里下分类好的资源,像阿里的天池大数据比赛,我用的数据集来自天池大数据数据集
数据集制作
tensorflow只能读取二进制数据,以tfrecord格式保存图片,所以我们必须先把矩阵型的图片数据转化。
import os
import tensorflow as tf
from PIL import Image
import numpy as np
data_path= 'data\\guangdong_round1_train2_20180916'#也可以输入你自己的test路径
writer=tf.python_io.TFRecordWriter('data\\train_data.tfrecord')
classes=['正常',
'不导电',
'擦花',
'横条压凹',
'桔皮',
'漏底',
'碰伤',
'起坑',
'凸粉',
'涂层开裂',
'脏点',
'其他'
]#也可以输入你自己的分类,这是比赛时要求的分类
class_path=[]
for index,name in enumerate(classes):
if name=='正常':
class_path.append(os.path.join(data_path,'无瑕疵样本'))
else:
if name=='其他':
new_path=os.path.join(data_path,'瑕疵样本','其他')
for newpath in os.listdir(new_path):
class_path.append(os.path.join(data_path,'瑕疵样本',name,newpath))
else:
class_path.append(os.path.join(data_path,'瑕疵样本',name))
#上面的过程是为每个类别的图片建一个收索路径,使每一张图片进入进入这个搜索路径,方便下文遍历搜索
#下一个遍历是让每张图片转换为tfrecord格式
for path in class_path:
for image_path in os.listdir(path):
image_path=os.path.join(path,image_path)
image= Image.open(image_path)
image=image.resize((256,256)) #图片太大,根据你的需求缩小尺寸
image_raw=image.tobytes()
example=tf.train.Example(features=tf.train.Features(feature={
'lable':tf.train.Feature(int64_list=tf.train.Int64List(value=[index])) ,
'image_raw':tf.train.Feature(byt