记录一下去年用 MobileNet 训练活体识别模型时,制作 TFRecords 数据集的过程:
1. 制作数据列表文件:生成一个 txt 文件,每行的格式为“样本文件路径 标签”(P、P2 目录下的正样本标签为 0,N、N2 目录下的负样本标签为 1),形式如下:
C:/Users/12544/Documents/train\P\image_19_depth.xml 0
C:/Users/12544/Documents/train\P\image_17_depth.xml 0
C:/Users/12544/Documents/train\N\image_6_depth.xml 1
C:/Users/12544/Documents/train\P\image_14_depth.xml 0
C:/Users/12544/Documents/train\P2\image_24_depth.xml 0
C:/Users/12544/Documents/train\N\image_5_depth.xml 1
C:/Users/12544/Documents/train\P2\image_21_depth.xml 0
C:/Users/12544/Documents/train\N2\image_13_depth.xml 1
C:/Users/12544/Documents/train\P2\image_23_depth.xml 0
C:/Users/12544/Documents/train\N2\image_11_depth.xml 1
C:/Users/12544/Documents/train\N\image_2_depth.xml 1
C:/Users/12544/Documents/train\N\image_0_depth.xml 1
C:/Users/12544/Documents/train\N\image_1_depth.xml 1
C:/Users/12544/Documents/train\N\image_8_depth.xml 1
C:/Users/12544/Documents/train\P\image_18_depth.xml 0
C:/Users/12544/Documents/train\P\image_15_depth.xml 0
C:/Users/12544/Documents/train\N2\image_12_depth.xml 1
C:/Users/12544/Documents/train\N\image_7_depth.xml 1
C:/Users/12544/Documents/train\P\image_16_depth.xml 0
C:/Users/12544/Documents/train\P2\image_26_depth.xml 0
C:/Users/12544/Documents/train\N2\image_9_depth.xml 1
C:/Users/12544/Documents/train\N\image_3_depth.xml 1
C:/Users/12544/Documents/train\N\image_4_depth.xml 1
C:/Users/12544/Documents/train\N2\image_10_depth.xml 1
C:/Users/12544/Documents/train\P2\image_25_depth.xml 0
C:/Users/12544/Documents/train\P\image_20_depth.xml 0
C:/Users/12544/Documents/train\P2\image_22_depth.xml 0
代码如下:
import os
import random
def get_list(path):
    """Write a shuffled ``<file path> <label>`` listing of .xml samples to list.txt.

    Each immediate sub-directory of ``path`` is treated as a class directory:
    ``N`` and ``N2`` are labelled 1, ``P`` and ``P2`` are labelled 0.
    Sub-directories with any other name are skipped, so their files can no
    longer inherit the label of the previously visited class.

    Args:
        path: Root directory that contains the class sub-directories.

    Side effects:
        Prints every collected sample and the total count, and (re)creates
        ``list.txt`` in the current working directory.
    """
    labels = {'N': 1, 'N2': 1, 'P': 0, 'P2': 0}

    # The first tuple yielded by os.walk describes `path` itself; index 1 holds
    # the names of its immediate sub-directories.
    classes = sorted(next(os.walk(path))[1])

    entries = []
    for c in classes:
        if c not in labels:
            continue  # not a known class directory
        name = labels[c]
        c_dir = os.path.join(path, c)
        for root, dirs, files in os.walk(c_dir):
            for filename in files:
                if filename.endswith('.xml'):
                    img_path = os.path.join(root, filename)
                    print ('文件名为:%s, 标签名为:%d' %(img_path,name))
                    entries.append(img_path + ' ' + str(name))

    # Shuffle here so that the dataset built from this list is already shuffled.
    random.shuffle(entries)
    print('总训练样本数:', len(entries))

    # The context manager guarantees the listing is flushed and closed.
    with open('list.txt', 'w') as list_file:
        list_file.writelines(entry + '\n' for entry in entries)
if __name__=='__main__':
    # Root folder holding the four class folders N, N2, P and P2:
    # N and N2 contain the negative samples, P and P2 the positive samples.
    get_list('C:/Users/12544/Documents/train')
2.制作tfrecords数据集
# -*- coding: utf-8 -*-
import os
import tensorflow as tf
import re
import numpy as np
import xml.dom.minidom
def read_xml(filepath):
    """Read the numbers stored in the first ``<data>`` element of an XML file.

    Args:
        filepath: Path of the XML file to parse.

    Returns:
        A 1-D numpy array holding every run of decimal digits found in the
        element's text, in document order. The elements are still strings;
        the caller converts them to a numeric dtype.

    Raises:
        IndexError: If the document contains no ``<data>`` element.
    """
    dom = xml.dom.minidom.parse(filepath)
    data_node = dom.getElementsByTagName('data')[0]
    text = data_node.firstChild.data
    # Raw string: "\d" in a normal literal is an invalid escape sequence.
    tokens = re.findall(r"\d+", text)
    return np.array(tokens)
# Convert every sample listed in list.txt into a tf.train.Example and append it
# to a single TFRecord file.
with open('./list.txt') as f:
    content = f.readlines()  # one "<xml path> <label>" entry per line

sample_count = 0
writer = tf.python_io.TFRecordWriter('train_float_1chanel.tfrecords')
try:
    for line in content:
        if not line.strip():
            continue  # tolerate blank lines in the list file
        # Split on the LAST run of whitespace so paths containing spaces survive.
        filename, label_text = line.rsplit(None, 1)
        label = int(label_text)
        print ('文件名为:%s, 标签名为:%d' %(filename,label))
        filedata = read_xml(filename)
        # NOTE: to emulate a 3-channel image instead of a single channel, repeat
        # every element three times and reshape to (128, 128, 3):
        # data = filedata.repeat(3)
        # data = data.reshape(128,128,3)
        data = filedata.reshape(128,128)
        data = data.astype(np.float32)
        img_raw = data.tobytes()  # raw float32 bytes of the 128x128 sample
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        }))
        writer.write(example.SerializeToString())  # serialize the Example to a byte string
        sample_count += 1
finally:
    # Close the writer even when a sample fails to parse, so the records
    # written so far are flushed to disk.
    writer.close()
print ('总训练样本数:',sample_count)
3.验证制作的数据集是否可用
import tensorflow as tf
from PIL import Image
# Read the generated TFRecord file back through a TF1 queue-based input
# pipeline and print the labels of 10 samples as a sanity check.
tfrecords_filename = 'C:/Users/12544/PycharmProjects/tools/train_float_1chanel.tfrecords'
filename_queue = tf.train.string_input_producer([tfrecords_filename])  # queue of input files
reader = tf.TFRecordReader()
# read() returns a pair; only the serialized record (second element) is used.
_, serialized_example = reader.read(filename_queue)
# Parse one record using the same two keys that were written: 'label' and 'img_raw'.
features = tf.parse_single_example(
    serialized_example,
    features={
        'label': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string),
    })
image = tf.decode_raw(features['img_raw'], tf.float32)  # bytes -> float32 values
image = tf.reshape(image, [128, 128, 1])
label = tf.cast(features['label'], tf.int64)

with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated in favour of
    # tf.global_variables_initializer().
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    try:
        for i in range(10):
            example, la = sess.run([image, label])  # fetch one sample and its label
            # Saving the sample as an image is disabled:
            # img = Image.fromarray(example,'RGB')
            # img.save('./' + str(i) + '_''Label_' + str(1) + '.jpg')
            print(la)
    finally:
        # Always stop and join the queue-runner threads, even if a read fails.
        coord.request_stop()
        coord.join(threads)
这里会开启一个会话,读取 tfrecords 数据集,并打印数据集中 10 个样本的标签(把样本数据另存为图片的代码已被注释掉)。
记录一下之前遇到的一个坑:之所以先生成 txt 列表文件,主要是为了 shuffle,这样做出的数据集本身就是打乱过的。虽然 TensorFlow 提供了 tf.train.shuffle_batch,但我之前使用时,这个函数没有起到打乱的效果。