A few follow-up notes on tfrecords

        After finishing yesterday's post (link to the previous post), I ran into a few more issues, all around tf.VarLenFeature(), and in working through them I also solved the problem left over from that post.

First, the resize issue mentioned in the previous post: the cause, surprisingly, turned out to be the tf.VarLenFeature() function, as explained below.

        Generally speaking, tf.VarLenFeature() is used to handle sparse tensors (see the official docs for details), but it is also what you need at parsing time when a tfrecords file stores features whose number of values differs from example to example and is not known in advance. For example, my data is generated like this:

        

def convert_to_example(image,label):
    
    data = Image.open(image)
    img = data
    w = img.width
    h = img.height
    shape = [h,w,3]
    img = img.tobytes()
    
    with open(label,'r',encoding = 'utf-8') as f:
        
        pos = [x.split(',') for x in f.readlines()]
        
    x1 = []
    y1 = []
    x2 = []
    y2 = []
    x3 = []
    y3 = []
    x4 = []
    y4 = []
    label = []
    for p in pos:
        label.append([p[-1].encode()])
        x1.append(float(p[0]))
        y1.append(float(p[1]))
        x2.append(float(p[2]))
        y2.append(float(p[3]))
        x3.append(float(p[4]))
        y3.append(float(p[5]))
        x4.append(float(p[6]))
        y4.append(float(p[7]))
    label = np.array(label).tobytes()            
    
    
    feature = {
            'image':_bytes_list(img),
            'image_shape':_int64_list(shape),
            'x1':_float_list(x1),
            'y1':_float_list(y1),
            'x2':_float_list(x2),
            'y2':_float_list(y2),
            'x3':_float_list(x3),
            'y3':_float_list(y3),
            'x4':_float_list(x4),
            'y4':_float_list(y4),
            'label':_bytes_list(label)
            }
    
    return tf.train.Example(features = tf.train.Features(feature = feature))
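
Before the full project code, here is a minimal, self-contained sketch of the same round trip; the file name and coordinate values below are made up purely for illustration. Two examples are written whose float lists have different lengths, and tf.VarLenFeature() together with tf.sparse_tensor_to_dense() recovers each list at its own length when parsing.

import tensorflow as tf

demo_path = 'varlen_demo.tfrecords'   # hypothetical file name

def _float_list(value):
    return tf.train.Feature(float_list = tf.train.FloatList(value = value))

# write two examples whose 'x1' feature holds 2 and 3 values respectively
with tf.python_io.TFRecordWriter(demo_path) as writer:
    for coords in ([1.0, 2.0], [3.0, 4.0, 5.0]):
        example = tf.train.Example(features = tf.train.Features(
                feature = {'x1': _float_list(coords)}))
        writer.write(example.SerializeToString())

def _parse(record):
    parsed = tf.parse_single_example(
            serialized = record,
            features = {'x1': tf.VarLenFeature(tf.float32)})
    # VarLenFeature yields a SparseTensor; convert it to a dense 1-D tensor
    x1 = tf.sparse_tensor_to_dense(parsed['x1'])
    x1 = tf.cast(x1, tf.float32)
    return x1

dataset = tf.data.TFRecordDataset(demo_path).map(_parse)
next_x1 = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    print(sess.run(next_x1))   # [1. 2.]
    print(sess.run(next_x1))   # [3. 4. 5.]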

        The data is still the Tianchi competition data from the previous post. Features like x1 and x2 hold a different number of values per image, so parsing them requires tf.VarLenFeature(). The full code is given below; note statements such as

x1 = tf.sparse_tensor_to_dense(parse_example['x1'])
x1 = tf.cast(x1,tf.float32)

tf.VarLenFeature() parses the feature into a SparseTensor, which has to be converted back to a dense tensor, and if you only do the conversion without the cast, TensorFlow throws some kind of internal error. Here is the full code:

        

# -*- coding: utf-8 -*-
"""
Created on Mon Apr 23 19:29:05 2018
tfrecords test
@author: lenovo
"""

import tensorflow as tf
from PIL import Image
import numpy as np


def _bytes_list(value):
    
    if not isinstance(value,list):
        value = [value]
    return tf.train.Feature(bytes_list = tf.train.BytesList(value = value))

def _int64_list(value):
    
    if not isinstance(value,list):
        value = [value]
        
    return tf.train.Feature(int64_list = tf.train.Int64List(value = value))

def _float_list(value):
    
    if not isinstance(value,list):
        value = [value]
    
    return tf.train.Feature(float_list = tf.train.FloatList(value = value))



def convert_to_example(image,label):
    
    data = Image.open(image)
    img = data
    w = img.width
    h = img.height
    shape = [h,w,3]
    img = img.tobytes()
    
    with open(label,'r',encoding = 'utf-8') as f:
        
        pos = [x.split(',') for x in f.readlines()]
        
    x1 = []
    y1 = []
    x2 = []
    y2 = []
    x3 = []
    y3 = []
    x4 = []
    y4 = []
    label = []
    for p in pos:
        label.append([p[-1].encode()])
        x1.append(float(p[0]))
        y1.append(float(p[1]))
        x2.append(float(p[2]))
        y2.append(float(p[3]))
        x3.append(float(p[4]))
        y3.append(float(p[5]))
        x4.append(float(p[6]))
        y4.append(float(p[7]))
    label = np.array(label).tobytes()            
    
    
    feature = {
            'image':_bytes_list(img),
            'image_shape':_int64_list(shape),
            'x1':_float_list(x1),
            'y1':_float_list(y1),
            'x2':_float_list(x2),
            'y2':_float_list(y2),
            'x3':_float_list(x3),
            'y3':_float_list(y3),
            'x4':_float_list(x4),
            'y4':_float_list(y4),
            'label':_bytes_list(label)
            }
    
    return tf.train.Example(features = tf.train.Features(feature = feature))

def convert_to_tfrecords(examples,path):
    
    assert type(examples) == list or type(examples) == tuple
    with tf.python_io.TFRecordWriter(path) as tfrecord_writer:
        i = 0
        for example in examples:
            tfrecord_writer.write(example.SerializeToString())
            i += 1
            print('%g examples written'%i)
            
def convert_to_tfrecords_(examples,path):
    
    '''
    When the examples are split across several files, the path can take this form:
    path like 'filename_%s_data_%s_of_%s_.tfrecords'
    '''
    
    assert type(examples) == list or type(examples) == tuple
    assert '%s' in path
    
    count = len(examples)
    j = 1000                            # examples per file
    total = (count + j - 1)//j          # number of files, rounded up
    tn = 0
    while tn*j < count:
        k = 0
        file = path%(j,tn+1,total)
        with tf.python_io.TFRecordWriter(file) as tfrecord_write:
            for i in range(tn*j,(tn+1)*j):
                if i >= count:
                    break
                tfrecord_write.write(examples[i].SerializeToString())
                k += 1
                print('file %g: wrote %g/%g examples'%(tn+1,k,j))
        tn += 1
            
def run(images,labels,path):
    
    examples = []
    for image,label in zip(images,labels):
        examples.append(convert_to_example(image,label))
    convert_to_tfrecords(examples,path)
    
def _parse_of_tfrecords(record):
    
    features = {
            'image':tf.FixedLenFeature([],tf.string),
            'image_shape':tf.FixedLenFeature([3],tf.int64),
            'x1':tf.VarLenFeature(tf.float32),
            'y1':tf.VarLenFeature(tf.float32),
            'x2':tf.VarLenFeature(tf.float32),
            'y2':tf.VarLenFeature(tf.float32),
            'x3':tf.VarLenFeature(tf.float32),
            'y3':tf.VarLenFeature(tf.float32),
            'x4':tf.VarLenFeature(tf.float32),
            'y4':tf.VarLenFeature(tf.float32),
            'label':tf.FixedLenFeature([],tf.string)
            }
    parse_example = tf.parse_single_example(serialized=record,features = features)
    image = tf.decode_raw(parse_example['image'],out_type=tf.uint8)
    shape = tf.cast(parse_example['image_shape'],tf.int32)
    x1 = tf.sparse_tensor_to_dense(parse_example['x1'])
    x1 = tf.cast(x1,tf.float32)
    y1 = tf.sparse_tensor_to_dense(parse_example['y1'])
    y1 = tf.cast(y1,tf.float32)
    x2 = tf.sparse_tensor_to_dense(parse_example['x2'])
    x2 = tf.cast(x2,tf.float32)
    y2 = tf.sparse_tensor_to_dense(parse_example['y2'])
    y2 = tf.cast(y2,tf.float32)
    x3 = tf.sparse_tensor_to_dense(parse_example['x3'])
    x3 = tf.cast(x3,tf.float32)
    y3 = tf.sparse_tensor_to_dense(parse_example['y3'])
    y3 = tf.cast(y3,tf.float32)
    x4 = tf.sparse_tensor_to_dense(parse_example['x4'])
    x4 = tf.cast(x4,tf.float32)
    y4 = tf.sparse_tensor_to_dense(parse_example['y4'])
    y4 = tf.cast(y4,tf.float32)
    x = [x1,x2,x3,x4]
    y = [y1,y2,y3,y4]
   
    label = parse_example['label']
    image = tf.reshape(image,shape = shape)
    return image,shape,label,x,y

def read_test(path):
    
    dataset = tf.data.TFRecordDataset(path)
    dataset = dataset.map(_parse_of_tfrecords)
    dataset = dataset.batch(1).repeat(1)
    
    iterator = dataset.make_one_shot_iterator()
    next_batch = iterator.get_next()
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for i in range(2):
            images,shapes,labels,xs,ys = sess.run(fetches = next_batch)
            print(xs,shapes)
    return images,shapes,labels,xs,ys
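
For completeness, here is a hypothetical usage of the functions above; the image and label paths are placeholders, not the real layout of the competition data.

images = ['data/images/img_1.jpg','data/images/img_2.jpg']   # placeholder paths
labels = ['data/labels/img_1.txt','data/labels/img_2.txt']   # placeholder paths

run(images,labels,'train.tfrecords')                          # write the tfrecords file
images,shapes,labels,xs,ys = read_test('train.tfrecords')     # read it back with tf.data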

        At this point, the data can at least be generated and read back. As for TensorFlow's own multithreaded (queue-based) reading, I still haven't got it working and keep hitting errors. The code is attached below as well, in the hope that someone passing by can point out what is wrong.

        

def read_for_base(filename):
    
    if not isinstance(filename,list):
        filename = [filename]
        
    filename_queue = tf.train.string_input_producer(filename)
    reader = tf.TFRecordReader()
    _,serialized_example = reader.read(filename_queue)  # returns the filename and the serialized example
    features = {}
    features['image'] = tf.FixedLenFeature([],tf.string)
    features['image_shape'] = tf.FixedLenFeature([3],tf.int64)
    features['x1'] = tf.VarLenFeature(tf.float32)
    features['label'] = tf.FixedLenFeature([],tf.string)
    
    parse_example = tf.parse_single_example(serialized_example,features=features)
    
    image = tf.decode_raw(parse_example['image'],tf.uint8)
    shape = tf.cast(parse_example['image_shape'],tf.int32)
    image = tf.reshape(image,shape)
    x = tf.sparse_tensor_to_dense(parse_example['x1'])
    x = tf.cast(x,tf.float32)
    label = parse_example['label']
    
    return image,label,x

def read_for_base_test(filename):
    
    images,labels,xs = read_for_base(filename)
    
#    image_batch,label_batch,xs_batch = tf.train.shuffle_batch(
#            [images,labels,xs],
#            batch_size=1,
#            min_after_dequeue=1,
#            capacity=2
#            )
    coord = tf.train.Coordinator()
    with tf.Session() as sess:
        
        sess.run(tf.global_variables_initializer())
        threads = tf.train.start_queue_runners(sess = sess,coord = coord)
        for i in range(2):
            image,label,xs = sess.run([images,labels,xs])
            print(image.shape,xs)
            
        coord.request_stop()
        coord.join(threads)

        With the shuffle_batch part commented out, the error is: TypeError: Fetch argument array([88.82, 30.91], dtype=float32) has invalid type <class 'numpy.ndarray'>, must be a string or Tensor. (Can not convert a ndarray into a Tensor or Operation.)

           With it left in, the error is: ValueError: All shapes must be fully defined: [TensorShape([Dimension(None), Dimension(None), Dimension(None)]), TensorShape([]), TensorShape([Dimension(None)])]

            To sum up, it pretty much looks like black magic!
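
For what it is worth, here is a guess at the two errors, offered as a sketch rather than a verified fix; the tfrecords path is hypothetical. The TypeError looks like a Python-side name-rebinding problem, and the ValueError comes from tf.train.shuffle_batch() requiring statically known shapes.

images,labels,xs = read_for_base('train.tfrecords')   # hypothetical path

# 1) TypeError: the original loop assigns the fetched numpy arrays back to the
#    same Python names ('xs' in particular), so on the second iteration
#    sess.run() receives an ndarray instead of a Tensor as a fetch argument.
#    Using different names for the fetched values avoids that.
coord = tf.train.Coordinator()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    threads = tf.train.start_queue_runners(sess = sess,coord = coord)
    for i in range(2):
        image_val,label_val,xs_val = sess.run([images,labels,xs])
        print(image_val.shape,xs_val)
    coord.request_stop()
    coord.join(threads)

# 2) ValueError: tf.train.shuffle_batch() expects every input tensor to have a
#    fully defined static shape, but here the image is [None,None,None] and x is
#    [None]. Giving them fixed shapes (resize the image, pad the box list to a
#    fixed length) or batching with tf.train.batch(...,dynamic_pad=True) should
#    remove that error.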
