用自己的数据集训练深度学习网络,需要将数据集加载到网络中去,这里我们把数据集转化成tfrecord格式来进行加载。
TFRecord 数据文件是一种将图像和标签统一存储的 二进制文件,它能更好的利用内存,在 TensorFlow 中快速的复制,移动,存储,读取等。
1. 整理好自己的数据集。将某一类图片放到一个文件夹内,如下图所示六类图片:
{'bo_luo_tai', 'chi_hen', 'dian_ci','hou_tai','lie_wen','zheng_chang'}
图片大小为256*256,保存地址为E: \Tongue_tfrecord_data\Ping_heng_Yuan_tongue_photes
2. 制作 d TFRecord 文件
TFRecord 会根据你选着输入的文件类,自动给每一类打上同样的标
签。在本列中,有 0,1,2 ,3,4,5六类。((二)、中会介绍将标签转化为one_hot独热编码格式)
3. python代码如下:
import os
import tensorflow as tf
from PIL import Image # 注意Image,后面会用到
import matplotlib.pyplot as plt
import numpy as np
import sys
import urllib
# 数据集的实际地址
# 文件夹的命名应避开关键词,否则会提示找不到路径
cwd = 'E:\Etodesktop\ML\SY_CNN_model_Tongue_phengheng\Tongue_tfrecord_data\Ping_heng_Yuan_tongue_photes\\'
# 人为的将数据设定为六类
classes = {'bo_luo_tai', 'chi_hen', 'dian_ci','hou_tai','lie_wen','zheng_chang'}
# 要生成的XX.TFRecord格式文件
writer = tf.python_io.TFRecordWriter("train.tfrecords")
#--------------------------------------------把数据存储为TFRrcord格式--------------------------------------------------#
'''
注意:
"label"和'img_raw'的名称和数据类型在生成TFRecord代码和将TFRecord读出时的代码一样
'''
for index, name in enumerate(classes): # enumerate()可以同时获得索引和元素
class_path = cwd + name + '\\'
for img_name in os.listdir(class_path): # os.listdir 返回指定目录下所有的文件和目录
img_path = class_path + img_name # 每一个图片的地址
img = Image.open(img_path)
img = img.resize((256, 256)) # 图片设置为256*256
img_raw = img.tobytes() # 将每一个图片转化为二进制格式
example = tf.train.Example(
features=tf.train.Features(
feature={
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[index])), # index 为标签
'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw]))
})) # example对象对label和image数据进行封装
writer.write(example.SerializeToString()) # 序列化为字符串
writer.close()
4. 生成的TFRecord数据格式如下图所示: