1. 写入 TFRecord
import tensorflow as tf
import os
import numpy as np
import PIL.Image as Image
def _get_path_label(image_dir, seed=None):
    """Collect image paths and integer class labels from a two-level directory tree.

    Every ``<top>/<sub>`` directory pair under *image_dir* is one class; every
    entry directly inside such a pair is one sample of that class.  Class ids
    are assigned in sorted order of the relative ``<top>/<sub>`` path.
    Non-directory entries at the first two levels are ignored.

    Args:
        image_dir: root directory; a leading ``~`` is expanded.
        seed: optional seed for the shuffle.  ``None`` (the default) uses
            NumPy's global random state, exactly as before.

    Returns:
        (shuffle_paths, shuffle_labels, cat_num): sample paths and their labels,
        shuffled with the same permutation, plus the number of classes.
    """
    image_dir = os.path.expanduser(image_dir)

    # Join with os.path.join instead of a hard-coded "\\" so the class paths
    # resolve on POSIX as well as Windows; skip stray files (e.g. Thumbs.db)
    # that would otherwise make os.listdir raise.
    ids = sorted(
        os.path.join(ford, sub_ford)
        for ford in os.listdir(image_dir)
        if os.path.isdir(os.path.join(image_dir, ford))
        for sub_ford in os.listdir(os.path.join(image_dir, ford))
        if os.path.isdir(os.path.join(image_dir, ford, sub_ford))
    )
    cat_num = len(ids)

    paths = []
    labels = []
    for label, cat in enumerate(ids):
        cur_dir = os.path.join(image_dir, cat)
        fns = os.listdir(cur_dir)
        paths.extend(os.path.join(cur_dir, fn) for fn in fns)
        labels.extend([label] * len(fns))

    # Shuffle paths and labels with one permutation so the pairs stay aligned.
    rng = np.random if seed is None else np.random.RandomState(seed)
    perm = rng.permutation(len(paths))
    shuffle_paths = [paths[i] for i in perm]
    shuffle_labels = [labels[i] for i in perm]
    return shuffle_paths, shuffle_labels, cat_num
def _byteslist(value):
    """Return a ``tf.train.Feature`` holding *value* as a one-element bytes list."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
def _int64list(value):
    """Return a ``tf.train.Feature`` holding *value* as a one-element int64 list."""
    payload = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=payload)
def creat_train_record(train_dir, train_record_path):
    """Write every image under *train_dir* into a single TFRecord file.

    Each record holds two features: ``label`` (the int64 class id produced by
    ``_get_path_label``) and ``img_raw`` (the RGB pixel bytes from PIL's
    ``Image.tobytes()``).  The image size is not stored, so a reader must know
    the shape in advance.  NOTE(review): the reader in this project reshapes to
    128x128x3 — confirm that all inputs have that size.

    Args:
        train_dir: root of the two-level class directory tree.
        train_record_path: output ``.tfrecords`` path.
    """
    shuffle_paths, shuffle_labels, cat_num = _get_path_label(train_dir)
    total = len(shuffle_labels)
    # Using the writer as a context manager flushes and closes the file even
    # if an image fails to load partway through.
    with tf.compat.v1.python_io.TFRecordWriter(train_record_path) as writer:
        for index, (image_name, label) in enumerate(zip(shuffle_paths, shuffle_labels)):
            # Close each image file promptly instead of leaking the handle.
            # Converting to RGB gives every record 3 bytes per pixel, which
            # matches the 3-channel decode on the read side (grayscale, palette
            # or RGBA images would otherwise produce a mismatched byte count).
            with Image.open(image_name) as img:
                img_raw = img.convert("RGB").tobytes()
            example = tf.train.Example(
                features=tf.train.Features(feature={
                    'label': _int64list(label),
                    'img_raw': _byteslist(img_raw)}))
            writer.write(example.SerializeToString())
            if index % 1000 == 0:
                print("current i is:", index, " all data is:", total)
    print('creat_train_record success !')
    print("cat_num:", cat_num)
# Convert the training image tree into one TFRecord file.
creat_train_record(
    train_dir=r'E:\dataset\face\train\temp',
    train_record_path=r'E:\dataset\face\traindata.tfrecords',
)
2. 读取 TFRecord
import tensorflow as tf
import os
from PIL import Image
import matplotlib.pyplot as plt
class ReadTfRecord:
    """Read a TFRecord file of ``label``/``img_raw`` examples as batches.

    Each serialized example is parsed into ``{'label': int64, 'img_raw': string}``
    and grouped into batches of *batch_size*.  Iterate ``self.batch`` and pass
    each element to :meth:`get_data` to obtain decoded images and labels.
    """

    def __init__(self, filename, batch_size, cat_num=1000, image_shape=(128, 128, 3)):
        """
        Args:
            filename: path of the TFRecord file to read.
            batch_size: number of examples per batch.
            cat_num: number of classes; stored on the instance but not used by
                this class.  Defaults to the previously hard-coded 1000.
            image_shape: (height, width, channels) that :meth:`get_data` uses
                to reshape the raw pixel bytes.  Defaults to (128, 128, 3).
        """
        self.filename = filename
        self.cat_num = cat_num
        self.image_shape = tuple(image_shape)
        # Must be set before .map() below, because map traces _parse_function,
        # which reads it.
        self.features = {
            'label': tf.io.FixedLenFeature([], tf.int64),
            'img_raw': tf.io.FixedLenFeature([], tf.string)}
        dataset = tf.data.TFRecordDataset(filename)
        dataset = dataset.repeat(1)  # a single pass over the file
        dataset = dataset.map(self._parse_function)  # parse each serialized example
        # Each element of the resulting dataset is one batch of batch_size
        # consecutive examples (the final batch may be smaller).
        self.batch = dataset.batch(batch_size=batch_size)

    def _parse_function(self, exam_proto):
        """Parse one serialized ``tf.train.Example`` using ``self.features``."""
        return tf.io.parse_single_example(exam_proto, self.features)

    def get_data(self, item):
        """Decode one batch element into ``(images, labels)``.

        Args:
            item: a batched dict from ``self.batch`` with keys ``'label'`` and
                ``'img_raw'``.

        Returns:
            images: float32 tensor of shape ``(batch, *image_shape)`` with
                values scaled into [0, 1].
            labels: the batch's label tensor, unchanged.
        """
        labels = item['label']
        raw = tf.io.decode_raw(item['img_raw'], tf.uint8)
        images = tf.reshape(raw, (-1,) + self.image_shape)
        images = tf.cast(images, tf.float32) / 255.0
        return images, labels
#TFRecord = ReadTfRecord(r'E:\dataset\face_recognition_train\data.tfrecords', 96)
# for i in range(5):
# for item in TFRecord.batch:
#
# img, label = TFRecord.get_data(item)
#
#
# print(img.shape, " ", label.shape)