记录一下最近遇到的问题
特别注意,在制作tfrercord的时候,写入的数据类型要和读出的时候保持一致,不然会出现多个矩阵的问题。
的萨芬
比如现有一个(100,100)的精度很高的矩阵,他的元素数据类型是float64,制作时是这个类型写进tfrecord,
但是在读出tfrecord的时候你用的类型是tf.float32,那么就会原来的写入的一个矩阵读出来会变成两个矩阵,应该是由于精度的问题,
data = tf.io.decode_raw(features['data'], out_type= tf.float64) #如果这里的out_type=换成tf.float32的话
- TF2里面的tfrecord数据格式制作与加载
制作没什么好讲的,代码如下:
def _make_tfrecords(doc_path, save_dir, train_or_test):
tfrecords_path = save_dir +'\\' +train_or_test +'.tfrecord'
if os.path.exists(tfrecords_path):
os.remove(tfrecords_path)
print('clean origional tfrecords')
def _bytes_feature(values):
return tf.train.Feature(bytes_list= tf.train.BytesList(value= [values]))
def _int_feature(values):
return tf.train.Feature(int64_list = tf.train.Int64List(value= [values]))
writer = tf.io.TFRecordWriter(tfrecords_path)
for i, single_image_path in enumerate(traverse_document(doc_path)):
# img = np.array(Image.open(img_path))
# 这里得到一个图片的真实标签,另一个标签tag_label.图片名字,图片数组,宽以及高
tru_label, tag_label, img_name, img, width, height= from_name_get_true_and_target_label(single_image_path)
# 字符串编码为二进制
img_name = img_name.encode('utf-8')
example = tf.train.Example(
features = tf.train.Features(
feature = {
'img_name':