记录一个在使用tensorflow把数据存储为tfr文件,之后在从tfr文件中读取时遇到的问题。
问题是:
InvalidArgumentError: Input to reshape is a tensor with 2944256 values, but the requested shape has 1472128 [[{{node Reshape_1}}]] [Op:IteratorGetNext]
这里面我做了一个reshape操作,它的意思是,我的tensor一共有2944256个value,但是我reshape后的tensor只有1472128个值,正好少了一半的数据没了。
出现这个问题的代码是这样的:
- 首先,我要把
dncVec
和target
这两个ndarray类型的多维数组,存为example中的值。
问题就出在target
,我们只看target,target是我用随机矩阵,随机生成的,代码为:
target = np.random.rand(1,896,1643)
整个写入过程都是没问题的。
seqLength = 131072
def data_generate(path):
#写入
with tf.io.TFRecordWriter(path) as writer:
#生成255个example
for i in range(5):
dna = randomSeq(seqLength)#dna序列
dnaVec = one_hot_encode(dna)#131072,4的矩阵
dnaVec = np.expand_dims(dnaVec,axis=0)#1,131072,4的矩阵
target = np.random.rand(1,896,1643)
#转成字节
dnaVec = dnaVec.tobytes()
target = target.tobytes()
feature = {
#序列使用的是tf.train.BytesList类型
'sequence':tf.train.Feature(bytes_list=tf.train.BytesList(value=[dnaVec])),
'target':tf.train.Feature(bytes_list=tf.train.BytesList(value=[target]))
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
writer.write(example.SerializeToString())
我们先看下np.random.rand()生成的矩阵:
这没问题,然后看下a的元素类型
>>> a.dtype
dtype('float64')
>>>
矩阵的内部元素的类型是float64
。
接着是读取tfr文件:
def parse_example(example_string):
#解析之后得到的example
example = tf.io.parse_single_example(example_string,feature_description)
#example['sequence']还是字节流的形式,重新转为数字向量
sequence = tf.io.decode_raw(example['sequence'], tf.float32)
sequence = tf.reshape(sequence,(1,seqLength,4)) #形状需要重塑,不然就是一个长向量
target = tf.io.decode_raw(example['target'], tf.float32)
target = tf.reshape(target,(1,896,1643))
#把整个字典返回
return {
'sequence':sequence,
'target': target
}
关键是这句代码:
target = tf.io.decode_raw(example['target'], tf.float32)
tf.io.decode_raw(
input_bytes, out_type, little_endian=True, fixed_length=None, name=None
)
tf.io.decode_raw()
将输入张量的原始字节转换成数字张量。
输入的是字节序列,然后把这些字节解码为 out_type
指定格式的数字。
原始数据是什么格式这里解析必须是什么格式,要不然会出现形状的不对应问题!
这里指定的是tf.float32
,但是之前随机生成矩阵中的元素类型是float64
,这就是问题所在。为什么会这样我也不清楚,但是只要把 target = tf.io.decode_raw(example['target'], tf.float32)
改为tf.float64
,就不会出现这个问题了。
或者提前把矩阵内元素的类型做一下变换,
target = target.astype(np.float32)
后面代码就不需要改变了。
拿出一个example
总之,问题应该就是矩阵元素的类型变换出错了。