制作自己的tfrecords数据集

记录一下去年用mobilenet训练活体识别制作tfrecords数据的过程:

1.制作数据list文件,生成txt文件,形式如下:

C:/Users/12544/Documents/train\P\image_19_depth.xml 0
C:/Users/12544/Documents/train\P\image_17_depth.xml 0
C:/Users/12544/Documents/train\N\image_6_depth.xml 1
C:/Users/12544/Documents/train\P\image_14_depth.xml 0
C:/Users/12544/Documents/train\P2\image_24_depth.xml 0
C:/Users/12544/Documents/train\N\image_5_depth.xml 1
C:/Users/12544/Documents/train\P2\image_21_depth.xml 0
C:/Users/12544/Documents/train\N2\image_13_depth.xml 1
C:/Users/12544/Documents/train\P2\image_23_depth.xml 0
C:/Users/12544/Documents/train\N2\image_11_depth.xml 1
C:/Users/12544/Documents/train\N\image_2_depth.xml 1
C:/Users/12544/Documents/train\N\image_0_depth.xml 1
C:/Users/12544/Documents/train\N\image_1_depth.xml 1
C:/Users/12544/Documents/train\N\image_8_depth.xml 1
C:/Users/12544/Documents/train\P\image_18_depth.xml 0
C:/Users/12544/Documents/train\P\image_15_depth.xml 0
C:/Users/12544/Documents/train\N2\image_12_depth.xml 1
C:/Users/12544/Documents/train\N\image_7_depth.xml 1
C:/Users/12544/Documents/train\P\image_16_depth.xml 0
C:/Users/12544/Documents/train\P2\image_26_depth.xml 0
C:/Users/12544/Documents/train\N2\image_9_depth.xml 1
C:/Users/12544/Documents/train\N\image_3_depth.xml 1
C:/Users/12544/Documents/train\N\image_4_depth.xml 1
C:/Users/12544/Documents/train\N2\image_10_depth.xml 1
C:/Users/12544/Documents/train\P2\image_25_depth.xml 0
C:/Users/12544/Documents/train\P\image_20_depth.xml 0
C:/Users/12544/Documents/train\P2\image_22_depth.xml 0

代码如下:

import os
import random

def get_list(path):
    classes = sorted(os.walk(path).__next__()[1])  #获取目录下的文件夹名
    file = []  #建立一个文件列表
    list = open('list.txt','w')  #要生成的list文件
    for c in classes:
        if c in ('N','N2'):
            name = 1
        if c in ('P','P2'):
            name = 0
        c_dir = os.path.join(path, c)
        for root, dirs, files in os.walk(c_dir):
            for filename in files:
                if filename.endswith('.xml'):
                    img_path = os.path.join(root,filename)
                    print ('文件名为:%s, 标签名为:%d' %(img_path,name))
                    file.append(img_path + ' ' + str(name))
    random.shuffle(file)
    print('总训练样本数:', len(file))

    for i in range(len(file)):
        line = str(file[i] + '\n')
        list.write(line)

if __name__=='__main__':
    imgpath = 'C:/Users/12544/Documents/train'  #这个目录下有N、N2、P、P2四个文件夹,N、N2里面是负样本,P、P2里面是正样本。
    get_list(imgpath)

2.制作tfrecords数据集

#conding = utf-8

import os
import tensorflow as tf
import re
import numpy as np
import xml.dom.minidom

def read_xml(filepath):
    dom = xml.dom.minidom.parse(filepath)
    datas = dom.getElementsByTagName('data')  #读取xml文件中'data'标签内的数据
    data = datas[0]
    f = data.firstChild.data
    filedata = re.findall("\d+", f)
    data = np.array(filedata)   #转成numpy矩阵
    return data

writer = tf.python_io.TFRecordWriter('train_float_1chanel.tfrecords')
with open('./list.txt') as f:
    content = f.readlines()  #按行读取txt文件
    sum = 0
    for line in content:
        filename = re.split('\s+', line)[0]
        label = int(re.split('\s+', line)[1])
        print ('文件名为:%s, 标签名为:%d' %(filename,label))
        filedata = read_xml(filename)
        # data = filedata.repeat(3) #复制数据里的每个元素三次
        # data = data.reshape(128,128,3) #将数据重新reshape,这里是模拟将单通道图转成三通道图
        data = filedata.reshape(128,128)
        data = data.astype(np.float32)
        img_raw = data.tobytes()  #将图片转化为原生bytes
        example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
        }))
        writer.write(example.SerializeToString())  #序列化为字符串
        sum = sum + 1
    writer.close()
    print ('总训练样本数:',sum)

3.验证制作的数据集是否可用

import tensorflow as tf
from PIL import Image

tfrecords_filename = 'C:/Users/12544/PycharmProjects/tools/train_float_1chanel.tfrecords'
filename_queue = tf.train.string_input_producer([tfrecords_filename])  #读入数据流
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)  #返回文件名和文件
features = tf.parse_single_example(serialized_example,
                                   features={
                                       'label': tf.FixedLenFeature([],tf.int64),
                                       'img_raw': tf.FixedLenFeature([],tf.string),
                                   })  #取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'],tf.float32)
image = tf.reshape(image,[128,128,1])
label = tf.cast(features['label'],tf.int64)

with tf.Session() as sess:  #开始一个会话
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    for i in range(10):
        example, la = sess.run([image, label]) #在会话中取出image和label
        # img = Image.fromarray(example,'RGB')
        # img.save('./' + str(i) + '_''Label_' + str(1) + '.jpg')
        print(la)

    coord.request_stop()
    coord.join(threads)

这里会开启一个会话,读取tfrecords数据集,返回数据集里面的10个样本标签(注释了读取样本数据,另存图片)

 

记录一下之前遇到的一个坑:之所以做txt文件,主要是为了shuffle,这样做出的数据集就是shuffle过的。虽然tensorflow有tf.train.shuffle_batch。但是我之前用的时候,这个函数没效果。

 

 
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值