问题
项目中需要把每张图片的标记向量存入TFrecords文件,参考了图片的存储方法,但会报错,经过几个小时的折腾才发现问题在哪里(汗颜)
代码
tfrecords_filename = './train2.tfrecords'
writer = tf.python_io.TFRecordWriter(tfrecords_filename)
#重点在于指定整数的类型
img_raw = np.array([12,14,13,15,16,17],dtype=np.int32)
img_raw = img_raw.tostring()
example = tf.train.Example(features=tf.train.Features(feature={
'label':tf.train.Feature(int64_list = tf.train.Int64List(value=[5])),
'img_raw':tf.train.Feature(bytes_list = tf.train.BytesList(value=[img_raw]))}))
writer.write(example.SerializeToString())
writer.close()
filename_queue = tf.train.string_input_producer([tfrecords_filename]) # 读入流中
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue) # 返回文件名和文件
features = tf.parse_single_example(serialized_example,
features={
'label': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string),
}) # 取出包含image和label的feature对象
image = tf.decode_raw(features['img_raw'], tf.int32)
image = tf.reshape(image, [6])
label = tf.cast(features['label'], tf.int64)
with tf.Session() as sess: # 开始一个会话
init_op = tf.initialize_all_variables()
sess.run(init_op)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
example, l = sess.run([image, label])
print(example, l)
coord.request_stop()
coord.join(threads)
在数据读出的时候,默认的整数类型为int32,之前都用的int64,发现输出总是少一半数据,换成int16后又多一半数据,因此,显式指定数据类型是很有必要的。
感谢这篇文章提供的总体框架,总算可以进行下去了~