tfrecord在读取时要注意读取后的数据必须与tfrecord写入之前的数据类型保持一致,否则会出现常见的两种错误:
一、无法reshape
ValueError: Cannot reshape a tensor with 1 elements to shape [256,256,4] (262144 elements) for 'Reshape' (op: 'Reshape') with input shapes: [], [3] and with input tensors computed as partial shapes: input[1] = [256,256,4].
这是由于在使用 tf.parse_single_example解析example时的数据类型错误,我的解析函数如下,解析出来的数据类型为str,而str类型的数据是无法进行reshape的
features = tf.parse_single_example(
record,
features={'label': tf.FixedLenFeature([], tf.string),
'image_raw': tf.FixedLenFeature([], tf.string)})
二、reshape维度错误
由于tfrecord读取后的数据与读取前的数据类型不同导致reshape维度错误
tensorflow.python.framework.errors_impl.InvalidArgumentError: Input to reshape is a tensor with 131072 values, but the requested shape has 262144
[[{{node Reshape}}]]
[[node IteratorGetNext (defined at /Users/‘’‘/PycharmProjects/test/unet_write_to_tfrecord.py:140) ]]
解决方法
这时可以通过tf.decode_raw这个函数来修改数据类型,我的写入tfrecord前的数据为256*256*4的uint16 array,所以在tfrecord读取时使用tf.decode_raw将数据类型改为uint16
image = tf.decode_raw(features['image_raw'], tf.uint16)