import os
import pickle

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
%matplotlib inline
# NOTE: download and extract the CIFAR-10 dataset (Python pickle version) into ./cifar10 first.
def load_data(path):
    """Load one pickled CIFAR-10 batch file.

    Args:
        path: path to a batch file, e.g. ``cifar10/data_batch_1``.

    Returns:
        (x, y): ``x`` is a float32 array of shape (N, 3 * 32 * 32) holding the
        pixel values divided by 255; ``y`` is an int64 array of the N labels.
    """
    # NOTE: unpickling can execute arbitrary code -- only load batch files
    # obtained from a trusted source.
    # encoding='latin1' yields str keys ('data', 'labels'); with
    # encoding='bytes' the keys would instead be b'data' / b'labels'.
    with open(path, 'rb') as fh:
        batch = pickle.load(fh, encoding='latin1')
    pixels = np.reshape(batch['data'], (-1, 3 * 32 * 32))
    images = pixels.astype(np.float32) / 255.
    labels = np.array(batch['labels']).astype(np.int64)
    return images, labels
def load_train(root='cifar10/data_batch_'):
    """Load the five CIFAR-10 training batches and stack them.

    Files ``root + '1'`` through ``root + '5'`` are read in order with
    ``load_data`` and concatenated along the first axis.

    Returns:
        (train_x, train_y): stacked image rows and labels, batch 1 first.
    """
    batches = [load_data(root + str(idx)) for idx in range(1, 6)]
    train_x = np.concatenate([images for images, _ in batches])
    train_y = np.concatenate([labels for _, labels in batches])
    return train_x, train_y
def load_test():
    """Load the CIFAR-10 test batch (``cifar10/test_batch``) via ``load_data``."""
    test_file = 'cifar10/test_batch'
    return load_data(test_file)
def create_tfRecords(name='train'):
    """Serialize one CIFAR-10 split into a TFRecord file under ``tfRecords/``.

    Each record is a ``tf.train.Example`` with two bytes features:
    ``'img'`` holds the raw bytes of one image row returned by the loader,
    and ``'label'`` holds the raw bytes of one label.

    Args:
        name: ``'train'`` writes ``tfRecords/train.tfrecords`` from
            ``load_train()``; any other value writes
            ``tfRecords/test.tfrecords`` from ``load_test()``.
    """
    if name == 'train':
        x, y = load_train()
        path = 'tfRecords/train.tfrecords'
    else:
        x, y = load_test()
        path = 'tfRecords/test.tfrecords'
    # Make sure the output directory exists before opening the writer.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    # Context manager guarantees the writer is flushed and closed, so the
    # record file is complete even if serialization fails part-way.
    with tf.python_io.TFRecordWriter(path) as writer:
        for img, label in zip(x, y):
            example = tf.train.Example(features=tf.train.Features(feature={
                'img': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[img.tobytes()])),
                'label': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[label.tobytes()])),
            }))
            writer.write(example.SerializeToString())
# Write both TFRecord files: the training split first, then the test split.
create_tfRecords(name='train')
create_tfRecords(name='test')
#read tfrecords
def _parse(example):
    """Decode one serialized ``tf.train.Example`` into ``(image, label)``.

    Returns:
        image: float32 tensor of shape [32, 32, 3] (height, width, channel).
        label: int64 scalar tensor.
    """
    feature_spec = {
        'img': tf.FixedLenFeature((), tf.string),
        'label': tf.FixedLenFeature((), tf.string),
    }
    parsed = tf.parse_single_example(example, feature_spec)
    # The image bytes are a flat float32 row, read as channel-first
    # [3, 32, 32] and transposed to channel-last [32, 32, 3].
    image = tf.transpose(
        tf.reshape(tf.decode_raw(parsed['img'], out_type=tf.float32), [3, 32, 32]),
        [1, 2, 0])
    # The label bytes decode to a length-1 int64 vector; collapse to a scalar.
    label = tf.reshape(tf.decode_raw(parsed['label'], tf.int64), [])
    return image, label
def load_tfRecords(name='train', batch_size=32, shuffle_buffer=1024):
    """Build a shuffled, batched, endlessly repeating input pipeline.

    Args:
        name: split name; reads ``tfRecords/<name>.tfrecords``.
        batch_size: examples per batch (default 32, as before).
        shuffle_buffer: number of parsed examples held in the shuffle
            buffer (default 1024, as before).

    Returns:
        The ``(images, labels)`` tensors from a one-shot iterator's
        ``get_next()``; each ``sess.run`` on them yields the next batch.
        Per ``_parse``, images are [batch, 32, 32, 3] float32 and labels are
        [batch] int64. Because ``repeat()`` follows ``batch()``, the last
        batch of each pass over the file may be smaller than ``batch_size``.
    """
    path = 'tfRecords/' + name + '.tfrecords'
    ds = (tf.data.TFRecordDataset(path)
          .map(_parse)
          .shuffle(shuffle_buffer)
          .batch(batch_size)
          .repeat())
    it = ds.make_one_shot_iterator()
    return it.get_next()
# Smoke test: build the training pipeline, fetch one batch and show its
# first image.
train_next = load_tfRecords()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    x, y = sess.run(train_next)
# sess.run returned plain NumPy arrays, so plotting needs no open session.
plt.imshow(x[0])
plt.show()
# cifar10 + tfrecords
# 最新推荐文章于 2020-12-29 00:06:12 发布 (scraped page footer: "latest recommended article published 2020-12-29 00:06:12")