TensorFlow 使用 TFRecords 读取数据

tensorflow的TFrecords读取数据

代码块

使用 TensorFlow 自带的数据读取方式(TFRecords 文件 + 队列)来读取数据(有时间再写一篇更详细的),例如:

# -*- coding: utf-8 -*-
import tensorflow as tf
import os
import cv2
import numpy as np
import random 
class Image_data():
    """Build and read the (img, diff) dataset used for training.

    Construction scans the dataset directories (``read_data``) and then
    immediately (re)writes a TFRecord file from them (``gen_tf``), so creating
    an instance has file-system side effects.

    Two TF1 queue-runner based reading pipelines are provided:

    * ``read_and_decode`` + ``get_batch`` decode the raw JPEG files directly;
    * ``read_and_decode_tf`` + ``get_batch_tf`` read the TFRecord file.
    """

    def __init__(self):
        self.read_data()
        self.gen_tf()

    def read_data(self, ann_path='I:/transfer/data/anns/',
                  img_path='I:/transfer/data/rec/', num_classes=11,
                  train_ratio=0.9):
        """Collect sample paths and split them into train/test sets.

        ``ann_path + str(i)`` (for ``i`` in ``range(num_classes)``) is listed,
        and every entry in it is treated as one annotated sample directory.
        For an entry named ``<id>_<suffix>`` the matching original image
        directory is ``img_path + str(i) + '/<id>_ori/'``.

        Args:
            ann_path: root of the annotated samples; must end with '/'.
            img_path: root of the original images; must end with '/'.
            num_classes: number of numbered sub-directories to scan.
            train_ratio: fraction of the shuffled samples used for training.

        Sets ``img_data``, ``img_label_data``, ``diff_data`` and
        ``diff_label_data`` (flat path lists in directory-listing order),
        ``data_dict`` (shuffled list of ``{'ann': dir, 'img': dir}`` dicts),
        ``train_data`` and ``test_data`` (a split of ``data_dict``).
        """
        self.img_data = []
        self.img_label_data = []
        self.diff_data = []
        self.diff_label_data = []
        # NOTE: despite its name this is a list of dicts, one per sample.
        self.data_dict = []
        for i in range(num_classes):
            ann_dir = ann_path + str(i) + '/'
            img_dir = img_path + str(i) + '/'
            for elem in os.listdir(ann_dir):
                sample_ann = ann_dir + elem + '/'
                sample_img = img_dir + elem.split('_')[0] + '_ori/'
                self.data_dict.append({'ann': sample_ann, 'img': sample_img})
                self.img_data.append(sample_img + 'img.jpg')
                self.img_label_data.append(sample_img + 'label.npy')
                self.diff_data.append(sample_ann + 'img.jpg')
                self.diff_label_data.append(sample_ann + 'label.npy')
        random.shuffle(self.data_dict)
        split = int(len(self.data_dict) * train_ratio)
        self.train_data = self.data_dict[:split]
        # Everything after the split point; slicing up to -1 would silently
        # drop the last sample.
        self.test_data = self.data_dict[split:]
        print(len(self.train_data))
        print(len(self.test_data))

    def get_single_data(self, queue):
        """Decode one (img, diff) pair of JPEG paths into BGR uint8 tensors.

        Args:
            queue: indexable pair ``(img_path, diff_path)`` of string tensors,
                as produced by ``tf.train.slice_input_producer``.

        Returns:
            ``(img, diff)`` 3-channel uint8 tensors in BGR channel order.
            ``decode_jpeg`` yields RGB, so channels are reversed to match the
            ``cv2.imread`` byte layout that ``gen_tf`` stores.
        """
        def _decode_bgr(path):
            # Decode a JPEG file and reorder its channels from RGB to BGR.
            rgb = tf.image.decode_jpeg(tf.read_file(path), channels=3)
            r, g, b = tf.split(value=rgb, num_or_size_splits=3, axis=2)
            return tf.cast(tf.concat([b, g, r], 2), dtype=tf.uint8)

        return _decode_bgr(queue[0]), _decode_bgr(queue[1])

    def read_and_decode(self):
        """Set up the raw-JPEG pipeline; sets ``self.img`` and ``self.diff``."""
        queue = tf.train.slice_input_producer(
            [self.img_data, self.diff_data], shuffle=False)
        self.img, self.diff = self.get_single_data(queue)

    def get_batch(self, batch_size, image_size=(192, 256)):
        """Return shuffled ``(img, diff)`` batches from the raw-JPEG pipeline.

        ``read_and_decode`` must have been called first.  The decoded JPEGs
        are not resized, so every image on disk must already be
        ``image_size`` = ``(height, width)``.
        """
        height, width = image_size
        shape = (height, width, 3)
        img, diff = tf.train.shuffle_batch(
            [self.img, self.diff], batch_size=batch_size,
            shapes=[shape, shape], capacity=batch_size * 11,
            min_after_dequeue=10, num_threads=4)
        return img, diff

    def gen_tf(self, output_path="I:/transfer/plan9_small_11/train.tfrecords"):
        """Serialize every sample in ``self.data_dict`` into a TFRecord file.

        Each record stores the raw (undecoded-pixel) bytes of the original
        image, its label, the annotated image and its label under the keys
        ``img``, ``img_label``, ``diff`` and ``diff_label``.  Images are read
        with ``cv2.imread`` (BGR), and labels are cast to uint8.  The reader
        (``read_and_decode_tf``) assumes a fixed image size; this writer does
        not check it.

        Raises:
            IOError: if an image file cannot be read.
        """
        def _bytes_feature(raw):
            # Wrap a raw byte string as a single-value bytes feature.
            return tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))

        def _read_image(path):
            # cv2.imread returns None instead of raising on a bad path.
            image = cv2.imread(path)
            if image is None:
                raise IOError('cannot read image: ' + path)
            return image.tobytes()

        def _read_label(path):
            # Pitfall: the stored dtype must match the dtype passed to
            # tf.decode_raw when parsing, otherwise parsing yields no data.
            return np.load(path).astype(np.uint8).tobytes()

        # The context manager guarantees the writer is closed even on error.
        with tf.python_io.TFRecordWriter(output_path) as writer:
            for sample in self.data_dict:
                example = tf.train.Example(features=tf.train.Features(feature={
                    "img": _bytes_feature(_read_image(sample['img'] + 'img.jpg')),
                    'img_label': _bytes_feature(_read_label(sample['img'] + 'label.npy')),
                    "diff": _bytes_feature(_read_image(sample['ann'] + 'img.jpg')),
                    'diff_label': _bytes_feature(_read_label(sample['ann'] + 'label.npy')),
                }))
                writer.write(example.SerializeToString())  # serialize to a string
        print('over')

    def read_and_decode_tf(self, filename, image_size=(192, 256), label_scale=255):
        """Set up the TFRecord reading pipeline for ``filename``.

        Sets ``self.img`` and ``self.diff`` (uint8, ``(h, w, 3)``) and
        ``self.img_label`` (uint8, ``(h, w, 1)``).  ``self.diff_label`` is the
        diff label repeated to 3 channels and multiplied by ``label_scale``
        so that a 0/1 mask becomes visible with ``cv2.imshow``.  The
        multiplication is in uint8 and wraps for products above 255.

        Args:
            filename: path of the TFRecord file written by ``gen_tf``.
            image_size: ``(height, width)`` of the stored images.
            label_scale: multiplier applied to ``diff_label`` (was 225, an
                apparent typo for the full uint8 scale 255).
        """
        # Build a filename queue from the single input file.
        filename_queue = tf.train.string_input_producer([filename])
        reader = tf.TFRecordReader()
        # reader.read returns (record key, serialized record).
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(serialized_example, features={
            'img': tf.FixedLenFeature([], tf.string),
            'img_label': tf.FixedLenFeature([], tf.string),
            'diff': tf.FixedLenFeature([], tf.string),
            'diff_label': tf.FixedLenFeature([], tf.string),
        })
        height, width = image_size

        def _decode(key, channels):
            # Raw uint8 bytes -> (height, width, channels) tensor.
            raw = tf.decode_raw(features[key], tf.uint8)
            return tf.reshape(raw, [height, width, channels])

        self.img = _decode('img', 3)
        self.diff = _decode('diff', 3)
        self.img_label = _decode('img_label', 1)
        diff_label = _decode('diff_label', 1)
        self.diff_label = tf.concat([diff_label] * 3, -1) * label_scale

    def get_batch_tf(self, batch_size):
        """Return ordered ``(img, diff, img_label, diff_label)`` batches.

        ``read_and_decode_tf`` must have been called first.
        """
        img, diff, img_label, diff_label = tf.train.batch(
            [self.img, self.diff, self.img_label, self.diff_label],
            batch_size=batch_size, capacity=batch_size * 11, num_threads=4)
        return img, diff, img_label, diff_label





if __name__ == '__main__':
    # Demo: read batches back from the TFRecord file and display them.
    BATCH_SIZE = 30
    NUM_STEPS = 20
    pro = Image_data()
    pro.read_and_decode_tf("I:/transfer/plan9_small_11/train.tfrecords")
    img, diff, img_label, diff_label = pro.get_batch_tf(BATCH_SIZE)
    sess = tf.InteractiveSession()
    coord = tf.train.Coordinator()
    # Keep the queue-runner threads so they can be joined on shutdown.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for _ in range(NUM_STEPS):
            img_, diff_, img_label_, diff_label_ = sess.run(
                [img, diff, img_label, diff_label])
            for j in range(BATCH_SIZE):
                cv2.imshow('img', img_[j])
                cv2.imshow('diff', diff_[j])
                cv2.imshow('diff_label', diff_label_[j])
                cv2.waitKey(0)
    finally:
        # Always stop and join the reader threads, even if display fails.
        coord.request_stop()
        coord.join(threads)
        sess.close()

    # Alternative: read the raw JPEG files directly instead of the TFRecord.
    # pro = Image_data()
    # pro.read_and_decode()
    # img, diff = pro.get_batch(30)
    # sess = tf.InteractiveSession()
    # coord = tf.train.Coordinator()
    # threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # for i in range(20):
    #     img_, diff_ = sess.run([img, diff])  # one batch per outer step
    #     for j in range(30):
    #         cv2.imshow('img', img_[j])
    #         cv2.imshow('diff', diff_[j])
    #         cv2.waitKey(0)
    # coord.request_stop()
    # coord.join(threads)
    # sess.close()


评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值