问题背景:用mobilenetv2提取某张图像的特征,结果提示:
原始代码:
train_data = tf.data.Dataset.from_tensor_slices((X_in))
train_data = train_data.map(mobi_parse_fun)
train_data = train_data.batch(1) # batch算一个维度
解决办法:
train_data = tf.data.Dataset.from_tensor_slices((X_in))
train_data = train_data.map(mobi_parse_fun).repeat().batch(1)
输出: