将自己的数据集制作成TFRecord格式

在使用TensorFlow训练神经网络时,首先面临的问题是:网络的输入

此篇文章,教大家将自己的数据集制作成TFRecord格式,feed进网络,除了TFRecord格式,TensorFlow也支持其他格

式的数据,此处就不再介绍了。建议大家使用TFRecord格式,在后面可以通过api进行多线程的读取文件队列。

1. 原本的数据集

此时,我有两类图片,分别是cat,dog,每一类中有10张图片。

 

2.制作成TFRecord格式

tfrecord会根据你选择输入文件的类,自动给每一类打上同样的标签。如在本例中,只有0,1 两类,想知道文件夹名与label关系的,可以自己保存起来。

#生成整数型的属性
def _int64_feature(value):
    return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))
 
#生成字符串类型的属性
def _bytes_feature(value):
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))
 
#制作TFRecord格式
def createTFRecord(filename,mapfile):
    class_map = {}
    data_dir = 'C:/Users/lenovo/Desktop/data/'
    classes = {'cat','dog'}
    #输出TFRecord文件的地址
  
    writer = tf.python_io.TFRecordWriter(filename)
 
    for index,name in enumerate(classes):
        class_path=data_dir+name+'/'
        class_map[index] = name
        for img_name in os.listdir(class_path):
            img_path = class_path + img_name   #每个图片的地址
            img = Image.open(img_path)
            img= img.resize((224,224))
            img_raw = img.tobytes()          #将图片转化成二进制格式
            example = tf.train.Example(features = tf.train.Features(feature = {
                'label':_int64_feature(index),
                'image_raw': _bytes_feature(img_raw)
            }))
            writer.write(example.SerializeToString())
    writer.close()
    
    txtfile = open(mapfile,'w+')
    for key in class_map.keys():
        txtfile.writelines(str(key)+":"+class_map[key]+"\n")
    txtfile.close()

 此段代码,运行完后会产生生成的.tfrecord文件和.txt文件。

3. 读取TFRecord的数据,进行解析,此时使用了文件队列以及多线程

#读取train.tfrecord中的数据
def read_and_decode(filename):   
    #创建一个reader来读取TFRecord文件中的样例
    reader = tf.TFRecordReader()
    #创建一个队列来维护输入文件列表
    filename_queue = tf.train.string_input_producer([filename], shuffle=False,num_epochs = 1)
    #从文件中读出一个样例,也可以使用read_up_to一次读取多个样例
    _,serialized_example = reader.read(filename_queue)
#     print _,serialized_example
 
    #解析读入的一个样例,如果需要解析多个,可以用parse_example
    features = tf.parse_single_example(
    serialized_example,
    features = {'label':tf.FixedLenFeature([], tf.int64),
               'image_raw': tf.FixedLenFeature([], tf.string),})
    #将字符串解析成图像对应的像素数组
    img = tf.decode_raw(features['image_raw'], tf.uint8)
    img = tf.reshape(img,[224, 224, 3]) #reshape为128*128*3通道图片
    img = tf.image.per_image_standardization(img)
    labels = tf.cast(features['label'], tf.int32)
    return img, labels

 

4. 将图片几个一打包,形成batch

def createBatch(filename,batchsize):
    images,labels = read_and_decode(filename)
   
    min_after_dequeue = 10
    capacity = min_after_dequeue + 3 * batchsize
 
    image_batch, label_batch = tf.train.shuffle_batch([images, labels], 
                                                        batch_size=batchsize, 
                                                        capacity=capacity, 
                                                        min_after_dequeue=min_after_dequeue
                                                        )
 
    label_batch = tf.one_hot(label_batch,depth=2)
    return image_batch, label_batch

 

5.主函数

if __name__ =="__main__":
    #训练图片两张为一个batch,进行训练,测试图片一起进行测试
    mapfile = "C:/Users/lenovo/Desktop/data/classmap.txt"
    train_filename = "C:/Users/lenovo/Desktop/data/train.tfrecords"
#    mapfile = "C:/Users/lenovo/Desktop/data"
#    train_filename = "C:/Users/lenovo/Desktop/data"
    createTFRecord(train_filename,mapfile)
    test_filename = "C:/Users/lenovo/Desktop/data/test.tfrecords"
    createTFRecord(test_filename,mapfile)
    image_batch, label_batch = createBatch(filename = train_filename,batchsize = 2)
    test_images,test_labels = createBatch(filename = test_filename,batchsize = 20)
    with tf.Session() as sess:
        initop = tf.group(tf.global_variables_initializer(),tf.local_variables_initializer())
        sess.run(initop)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess = sess, coord = coord)
 
        try:
            step = 0
            while 1:
                _image_batch,_label_batch =  sess.run([image_batch,label_batch])
                step += 1
                print (step)
                print (_image_batch.shape)
                print (_label_batch)
        except tf.errors.OutOfRangeError:
            print (" trainData done!")
            
        try:
            step = 0
            while 1:
                _test_images,_test_labels =  sess.run([test_images,test_labels])
                step += 1
                print (step)
                print ( _test_images.shape)
    #                 print _image_batch.shape
                print (_test_labels)
        except tf.errors.OutOfRangeError:
            print (" TEST done!")
        coord.request_stop()
        coord.join(threads)

 

此时,生成的batch,就可以feed进网络了。

 

  • 0
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值