TensorFlow 的 TFRecords 读取数据
代码块
使用 TensorFlow 自带的读取数据的方式来读取(有时间再写一个详细的版本),例如:
# -*- coding: utf-8 -*-
import tensorflow as tf
import os
import cv2
import numpy as np
import random
class Image_data():
    """Paired-image data pipeline (TF 1.x style).

    Scans a directory tree of (annotation, original) image pairs, writes
    every sample into a TFRecords file, and exposes queue-based readers
    that yield fixed-size 192x256 uint8 batches.
    """

    def __init__(self):
        # Collect all file paths, then immediately serialize the whole
        # dataset to a TFRecords file on disk.
        self.read_data()
        self.gen_tf()

    def read_data(self):
        """Walk the annotation/image directories and collect file paths.

        Populates parallel path lists (img/img_label/diff/diff_label) and a
        list of {'ann': ..., 'img': ...} record dicts, shuffles the dicts,
        and splits them 90/10 into ``train_data`` / ``test_data``.
        """
        self.img_data = []
        self.img_label_data = []
        self.diff_data = []
        self.diff_label_data = []
        self.data_dict = []
        ann_path = 'I:/transfer/data/anns/'
        img_path = 'I:/transfer/data/rec/'
        for i in range(11):  # 11 class sub-directories: 0 .. 10
            for elem in os.listdir(ann_path + str(i)):
                # elem looks like '<id>_...'; the matching original image
                # directory is '<id>_ori' -- TODO confirm naming convention.
                self.data_dict.append({'ann': ann_path + str(i) + '/' + elem + '/',
                                       'img': img_path + str(i) + '/' + elem.split('_')[0] + '_ori/'})
                self.img_data.append(img_path + str(i) + '/' + elem.split('_')[0] + '_ori/' + 'img.jpg')
                self.img_label_data.append(img_path + str(i) + '/' + elem.split('_')[0] + '_ori/' + 'label.npy')
                self.diff_data.append(ann_path + str(i) + '/' + elem + '/' + 'img.jpg')
                self.diff_label_data.append(ann_path + str(i) + '/' + elem + '/' + 'label.npy')
        random.shuffle(self.data_dict)
        length = len(self.data_dict)
        self.train_data = self.data_dict[0:int(length * 0.9)]
        # BUG FIX: the original sliced [...:-1], silently dropping the last
        # sample from the test split; slice to the end instead.
        self.test_data = self.data_dict[int(length * 0.9):]
        print(len(self.train_data))
        print(len(self.test_data))

    def get_single_data(self, queue):
        """Decode one (img, diff) pair from a slice-input queue.

        Both JPEGs are decoded as RGB, then the channels are reordered to
        BGR so the tensors match cv2 conventions. Returns two uint8
        HxWx3 tensors.
        """
        img = tf.image.decode_jpeg(tf.read_file(queue[0]), channels=3)
        img_r, img_g, img_b = tf.split(value=img, num_or_size_splits=3, axis=2)
        img = tf.cast(tf.concat([img_b, img_g, img_r], 2), dtype=tf.uint8)
        diff = tf.image.decode_jpeg(tf.read_file(queue[1]), channels=3)
        diff_r, diff_g, diff_b = tf.split(value=diff, num_or_size_splits=3, axis=2)
        diff = tf.cast(tf.concat([diff_b, diff_g, diff_r], 2), dtype=tf.uint8)
        return img, diff

    def read_and_decode(self):
        """Set up the file-name based input pipeline (non-TFRecords path)."""
        queue = tf.train.slice_input_producer([self.img_data, self.diff_data], shuffle=False)
        self.img, self.diff = self.get_single_data(queue)

    def get_batch(self, batch_size):
        """Return shuffled (img, diff) batches from the file-name pipeline."""
        img, diff = tf.train.shuffle_batch(
            [self.img, self.diff], batch_size=batch_size,
            shapes=[(192, 256, 3), (192, 256, 3)],
            capacity=batch_size * 11, min_after_dequeue=10, num_threads=4)
        return img, diff

    def gen_tf(self):
        """Serialize every sample (img/img_label/diff/diff_label) to TFRecords."""
        writer = tf.python_io.TFRecordWriter("I:/transfer/plan9_small_11/train.tfrecords")
        try:
            for record in self.data_dict:
                img = cv2.imread(record['img'] + 'img.jpg').tobytes()
                # Pitfall: the dtype written here must match what decode_raw
                # reads later, otherwise parsing yields garbage -- hence the
                # explicit astype(np.uint8).
                img_label = np.load(record['img'] + 'label.npy').astype(np.uint8).tobytes()
                diff = cv2.imread(record['ann'] + 'img.jpg').tobytes()
                diff_label = np.load(record['ann'] + 'label.npy').astype(np.uint8).tobytes()
                example = tf.train.Example(features=tf.train.Features(feature={
                    "img": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img])),
                    'img_label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_label])),
                    "diff": tf.train.Feature(bytes_list=tf.train.BytesList(value=[diff])),
                    'diff_label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[diff_label]))
                }))
                writer.write(example.SerializeToString())  # serialize to a byte string
            print('over')
        finally:
            # BUG FIX: close the writer even when an imread/np.load fails,
            # so a partial file is at least flushed and the handle released.
            writer.close()

    def read_and_decode_tf(self, filename):
        """Build the TFRecords-based input pipeline.

        Parses fixed-size uint8 tensors: ``img``/``diff`` are 192x256x3 BGR
        images, ``img_label`` is 192x256x1, and ``diff_label`` is replicated
        to 3 channels and scaled to full uint8 range for visualization.
        """
        filename_queue = tf.train.string_input_producer([filename])
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)  # returns (key, record)
        features = tf.parse_single_example(serialized_example,
                                           features={
                                               'img': tf.FixedLenFeature([], tf.string),
                                               'img_label': tf.FixedLenFeature([], tf.string),
                                               'diff': tf.FixedLenFeature([], tf.string),
                                               'diff_label': tf.FixedLenFeature([], tf.string),
                                           })
        self.img = tf.reshape(tf.decode_raw(features['img'], tf.uint8), [192, 256, 3])
        self.diff = tf.reshape(tf.decode_raw(features['diff'], tf.uint8), [192, 256, 3])
        self.img_label = tf.reshape(tf.decode_raw(features['img_label'], tf.uint8), [192, 256, 1])
        self.diff_label = tf.reshape(tf.decode_raw(features['diff_label'], tf.uint8), [192, 256, 1])
        self.diff_label = tf.concat([self.diff_label, self.diff_label, self.diff_label], -1)
        # BUG FIX: scale the 0/1 mask by 255 so it renders as full white;
        # the original multiplied by 225, which looks like a typo.
        self.diff_label = self.diff_label * 255

    def get_batch_tf(self, batch_size):
        """Return FIFO (non-shuffled) batches from the TFRecords pipeline."""
        img, diff, img_label, diff_label = tf.train.batch(
            [self.img, self.diff, self.img_label, self.diff_label],
            batch_size=batch_size, capacity=batch_size * 11, num_threads=4)
        return img, diff, img_label, diff_label
if __name__ == '__main__':
    # Build the pipeline (this also rewrites the TFRecords file via __init__)
    # and visually inspect a few batches with cv2.
    pro = Image_data()
    pro.read_and_decode_tf("I:/transfer/plan9_small_11/train.tfrecords")
    img, diff, img_label, diff_label = pro.get_batch_tf(30)
    sess = tf.InteractiveSession()
    coord = tf.train.Coordinator()
    # BUG FIX: keep the returned thread list; the original discarded it and
    # then referenced the undefined name `threads` in coord.join(), which
    # raised NameError at shutdown.
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    for i in range(20):
        img_, diff_, img_label_, diff_label_ = sess.run([img, diff, img_label, diff_label])
        for j in range(30):
            cv2.imshow('img', img_[j])
            cv2.imshow('diff', diff_[j])
            cv2.imshow('diff_label', diff_label_[j])
            cv2.waitKey(0)  # wait for a keypress between samples
    coord.request_stop()
    coord.join(threads)
    sess.close()
# pro=Image_data()
# pro.read_and_decode()
# img,diff=pro.get_batch(30)
# sess=tf.InteractiveSession()
# tf.train.start_queue_runners(sess=sess)
# for i in range(20):
# for j in range(30):
# img_,diff_ = sess.run([img,diff])
# cv2.imshow('img',img_[j])
# cv2.imshow('diff',diff_[j])
# cv2.waitKey(0)