继上一篇博文,写到如何将非string类型的数组写入tfrecords文件,写入文件后读取有掉进坑里了,读出来的格式不正确,由于读frecords代码依然是照抄示例,将图片转成string写入tfrecords文件,然后再按照string的类型通过decode_raw将读取的内容转译回数组,
def _parse_function(example_proto):
features = {'image':tf.FixedLenFeature([], tf.string),
'label':tf.FixedLenFeature([], tf.int64)}
parsed_features = tf.parse_single_example(example_proto, features)
img = tf.decode_raw(parsed_features['image'], tf.uint8)
img = tf.reshape(img, [128, 128, 1])
# 在流中抛出img张量和label张量
img = tf.cast(img, tf.float32) / 255
label = tf.cast(parsed_features['label'], tf.int64)
return img, label
但是我的图片数据写入并非按照string的格式写入的,而是以int64_list的格式写入的:
example = tf.train.Example(features=tf.train.Features(feature={
"image": tf.train.Feature(int64_list=tf.train.Int64List(value=newDataFlat)),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))
读取的时候就有问题,继续查如何读写tfrecords文件,依然是我照抄的样例的方式,不能解决我的问题。那就根据调试器报错一步一步的解决错误,却发现已经有贤者遇到并解决此类问题,由于我对读写tfrecords文件的函数只是一知半解,根本不知道还有可能会存在错误,根据大神的提示,才明白用法,而查函数调用方法根本就没有任何帮助,
解决方法如下,
方案一:如果写入tfrecords文件时的数据数组是固定大小的,用tf.FixedLenFeature()读取的时候需要指定数据数组的长度,例如:
def _parse_function(example_proto):
features = {'image':tf.FixedLenFeature([10000], tf.int64),
'label':tf.FixedLenFeature([], tf.int64)}
...
这样就能够将相同长度写入tfrecords文件的数据读取出来。
方案二:
不用tf.FixedLenFeature(),而是用tf.VarLenFeature(tf.int64)
def _parse_function(example_proto):
features = {'image':tf.VarLenFeature( tf.string),
'label':tf.VarLenFeature(tf.int64)}
parsed_features = tf.parse_single_example(example_proto, features)
data = tf.sparse_tensor_to_dense(features['data'], default_value=0)
data = tf.reshape(data, [20,50])
...
根据这个方法,又查tf.sparse_tensor_to_dense()函数的作用,是将稀疏矩阵转换为稠密矩阵。
what?稀疏矩阵,函数接口没看到说明返回结果是个系数矩阵啊,然后继续查,有大神介绍说用该方法读取的tfrecords文件确实是稀疏矩阵,新手出坑挺难啊,处处是坑。
参考:
https://blog.csdn.net/qq_45391763/article/details/103562357