# Write data to a TFRecord file
import numpy as np
import tensorflow as tf

def create_tf_record(inputs, labels, tfrecords_filename):
    writer = tf.python_io.TFRecordWriter(tfrecords_filename)
    for input_, label in zip(inputs, labels):
        # Serialize one example: both the data and the label are stored
        # as raw byte strings so they can be recovered with tf.decode_raw
        raw = np.array(input_, dtype=np.float32).tostring()
        label_raw = np.array(label, dtype=np.int64).tostring()
        example = tf.train.Example(features=tf.train.Features(
            feature={
                'label': tf.train.Feature(bytes_list=tf.train.BytesList(value=[label_raw])),
                'data': tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw]))
            }))
        writer.write(example.SerializeToString())
    writer.close()
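
# A minimal call sketch (hypothetical data and path): 100 random
# 40x40x3 float32 images with length-10 one-hot int64 labels, matching
# the shapes the parser below expects.
inputs = np.random.rand(100, 40, 40, 3).astype(np.float32)
labels = np.eye(10, dtype=np.int64)[np.random.randint(0, 10, size=100)]
create_tf_record(inputs, labels, '../tdrecord/1.tfrecords')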
# Read batches back with the Dataset API
def get_batch(batchSize=10):
    # Parse one serialized tf.train.Example
    def parser(record):
        features = tf.parse_single_example(record,
            features={
                # Both features were written as raw byte strings,
                # so parse them as tf.string and decode below
                'label': tf.FixedLenFeature([], tf.string),
                'data': tf.FixedLenFeature([], tf.string),
            })
        data = tf.decode_raw(features['data'], tf.float32)
        label = tf.decode_raw(features['label'], tf.int64)
        # Restore the original shapes
        data = tf.reshape(data, [40, 40, 3])
        label = tf.reshape(label, [10])
        return data, label
    # Input file paths
    tfrecords_filenames = ['../tdrecord/1.tfrecords', '../tdrecord/2.tfrecords']
    dataset = tf.data.TFRecordDataset(tfrecords_filenames)
    dataset = dataset.map(parser)
    # Keep a shuffle buffer of 500 examples, repeat the data 100 times,
    # and emit batchSize examples per batch
    dataset = dataset.shuffle(500).repeat(100).batch(batchSize)
    # Produce the next-batch tensors
    iterator = dataset.make_one_shot_iterator()
    data, label = iterator.get_next()
    return data, label
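
# A short usage sketch (assuming the TFRecord files above exist): pull
# batches in a tf.Session until the repeated dataset is exhausted.
data_batch, label_batch = get_batch(batchSize=32)
with tf.Session() as sess:
    try:
        while True:
            d, l = sess.run([data_batch, label_batch])
            print(d.shape, l.shape)  # e.g. (32, 40, 40, 3) (32, 10)
    except tf.errors.OutOfRangeError:
        # Raised once all repeats have been consumed
        pass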