TensorLayer出现数据形状转换错误

参考链接:

TensorLayer/tutorial_cifar10_cnn_static.py at master · tensorlayer/TensorLayer (github.com)

1.错误部分

使用TensorLayer来进行CIFAR-10 数据集上的图像分类,直接运行源文件的时候出现了如下错误:

InvalidArgumentError: Input to reshape is a tensor with 128 values, but the requested shape has 1
	 [[{{node Reshape}}]]

出现错误部分的代码如下:

def _map_fn_train(img, target):
    # 1. Randomly crop a [height, width] section of the image.
    img = tf.image.random_crop(img, [24, 24, 3])
    # 2. Randomly flip the image horizontally.
    img = tf.image.random_flip_left_right(img)
    # 3. Randomly change brightness.
    img = tf.image.random_brightness(img, max_delta=63)
    # 4. Randomly change contrast.
    img = tf.image.random_contrast(img, lower=0.2, upper=1.8)
    # 5. Subtract off the mean and divide by the variance of the pixels.
    img = tf.image.per_image_standardization(img)
    target = tf.reshape(target, ())
    return img, target
    
# dataset API and augmentation
train_ds = tf.data.Dataset.from_generator(
    generator_train, output_types=(tf.float32, tf.int32)
)  # , output_shapes=((24, 24, 3), (1)))
# train_ds = train_ds.repeat(n_epoch)
train_ds = train_ds.shuffle(shuffle_buffer_size)
train_ds = train_ds.prefetch(buffer_size=4096)
train_ds = train_ds.batch(batch_size)
train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())

2.具体分析

问题的具体描述是出现了数据形状转换错误,而通过排查可以确定问题就出现在map进行的数据转换上,其具体原因为train_ds在进行map转换时,首先进行了batch操作,将数据集转化为了小批量数据的格式,而map函数进行操作时的操作对象是单一的数据,因此数据格式出现了冲突,导致了该问题的发生。

3.解决办法

我们需要在进行batch前先进行map操作,完成转换后再进行小批量处理。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值