使用Tensorflow创建自己的数据集,并训练
介绍环境
win10 + pycharm + CPU
介绍背景
要求用卷积神经网络对不同水分的玉米进行分类(最后的目标是实现回归,以后研究)。神经网络虽然是科研神器,但是在工业上的应用效果远远不如实验室中的好。我们能找到的教程无非是mnist、表情识别等官方数据集。对于一个小白来说,虽然上手容易,但是收获着实有限。这篇博客我希望把每个知识点都讲得尽量通俗易懂,希望这篇处女作可以给小白一些指导。文中部分代码参考了ywx1832990,在此感谢。受限于水平,如有讲解错误的地方,欢迎留言探讨。
话不多说 直接上代码
step1:建立两个TFrecords
# pycharm中此模块名为genertateds.py
import os
import tensorflow as tf
from PIL import Image
# Root directory of the source images; each class has its own sub-directory.
cwd = r'C:\Users\pc\Desktop\orig_picture'
# Output paths (directory + file name) of the generated TFRecord files.
train_record_path = r"C:\Users\pc\Desktop\outputdata\train.tfrecords"
test_record_path = r"C:\Users\pc\Desktop\outputdata\test.tfrecords"
# Class names, which are also the sub-directory names under cwd.
# This must be an ordered sequence: the integer label of a class is its
# position here, and iterating a set of strings gives an order that can
# change between interpreter runs, which would scramble the labels.
classes = ['11.8', '13', '14.8', '16.5', '18', '20.6', '22.8', '26.1', '28.7', '30.6']
def _byteslist(value):
    """Wrap one bytes value in a ``tf.train.Feature`` carrying a BytesList."""
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)
def _int64list(value):
    """Wrap one integer value in a ``tf.train.Feature`` carrying an Int64List."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)
def _write_split_record(record_path, split_name, use_head):
    """Write one split of every class directory into a TFRecord file.

    Args:
        record_path: Destination path of the TFRecord file.
        split_name: ``'train'`` or ``'test'``; only used in progress messages.
        use_head: If true, the first 70% of each class's images are written,
            otherwise the remaining 30%.
    """
    count = 1  # running record number, printed as progress
    # The context manager closes the writer even when an image fails to load.
    with tf.python_io.TFRecordWriter(record_path) as writer:
        for index, name in enumerate(classes):
            class_path = os.path.join(cwd, name)
            # os.listdir returns entries in arbitrary order; sorting pins the
            # 70/30 boundary to the same files for both splits.
            img_names = sorted(os.listdir(class_path))
            split_at = int(len(img_names) * 0.7)
            selected = img_names[:split_at] if use_head else img_names[split_at:]
            for img_name in selected:
                with Image.open(os.path.join(class_path, img_name)) as img:
                    # read_record() reshapes to [128, 128, 3], so force three
                    # channels before dumping the raw pixel bytes.
                    img_raw = img.convert('RGB').resize((128, 128)).tobytes()
                example = tf.train.Example(
                    features=tf.train.Features(feature={
                        "label": _int64list(index),       # label must be an int64 feature
                        'img_raw': _byteslist(img_raw),   # image must be a bytes feature
                    }))
                writer.write(example.SerializeToString())
                print('Creating %s record in ' % split_name, count)
                count += 1
    print("Create %s_record successful!" % split_name)


def create_train_record():
    """Create the training TFRecord from the first 70% of each class's images."""
    _write_split_record(train_record_path, 'train', True)


def create_test_record():
    """Create the test TFRecord from the remaining 30% of each class's images."""
    _write_split_record(test_record_path, 'test', False)
def read_record(filename):
    """Build the graph ops that read one (image, label) example from a TFRecord.

    Args:
        filename: Path of the TFRecord file to read.

    Returns:
        A tuple ``(img, label)``: ``img`` is a float32 tensor reshaped to
        [128, 128, 3] and scaled to the range [-0.5, 0.5]; ``label`` is the
        stored label cast to int32.
    """
    # File-name queue first, then the reader that consumes it.
    file_queue = tf.train.string_input_producer([filename])
    record_reader = tf.TFRecordReader()
    _, serialized = record_reader.read(file_queue)

    feature_spec = {
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string)
    }
    parsed = tf.parse_single_example(serialized, features=feature_spec)

    # Raw bytes -> uint8 pixels -> fixed image shape.
    pixels = tf.decode_raw(parsed['img_raw'], tf.uint8)
    pixels = tf.reshape(pixels, [128, 128, 3])
    # Normalise: map 0..255 to -0.5..0.5.
    img = tf.cast(pixels, tf.float32) * (1. / 255) - 0.5
    label = tf.cast(parsed['label'], tf.int32)
    return img, label
def get_batch_record(filename, batch_size):
    """Return randomly shuffled batches of images and labels from a TFRecord.

    Args:
        filename: Path of the TFRecord file to read.
        batch_size: Number of examples per batch.

    Returns:
        A tuple ``(image_batch, label_batch)`` produced by
        ``tf.train.shuffle_batch``.
    """
    single_image, single_label = read_record(filename)
    # Draw batch_size random (image, label) pairs from the shuffle queue.
    batched = tf.train.shuffle_batch(
        [single_image, single_label],
        batch_size=batch_size,
        capacity=2000,
        min_after_dequeue=1000,
    )
    image_batch, label_batch = batched
    return image_batch, label_batch
def main():
    """Generate the training TFRecord, then the test TFRecord."""
    for build_record in (create_train_record, create_test_record):
        build_record()


# Only build the records when this module is run as a script.
if __name__ == '__main__':
    main()
这里值得一提的是,from PIL import Image 在jupyter某次更新后会出现无法使用的现象。建议使用jupyter的朋友不要更新;如果已经更新,可以卸载后重新安装之前的版本。
注意 windows下 cwd = r’C:\Users\pc\Desktop\orig