TensorFlow pitfall notes: tf.nn.conv2d_transpose

When using TensorFlow for image processing (image autoencoders, image segmentation, image super-resolution, image fusion, and other image-generation pipelines), you inevitably hit the pattern of downsampling followed by upsampling. Most people reach for conv2d_transpose, i.e. transposed (often called "de-") convolution, for the upsampling step, but there are a few pitfalls I ran into that are worth recording here.
First, the official documentation for conv2d_transpose:

tf.nn.conv2d_transpose(
    value=None,
    filter=None,
    output_shape=None,
    strides=None,
    padding='SAME',
    data_format='NHWC',
    name=None,
    input=None,
    filters=None,
    dilations=None
)
value: A 4-D Tensor of type float and shape [batch, height, width, in_channels] for NHWC data format or [batch, in_channels, height, width] for NCHW data format.
filter: A 4-D Tensor with the same type as value and shape [height, width, output_channels, in_channels]. filter's in_channels dimension must match that of value.
output_shape: A 1-D Tensor representing the output shape of the deconvolution op.
strides: An int or list of ints that has length 1, 2 or 4. The stride of the sliding window for each dimension of input. If a single value is given it is replicated in the H and W dimension. By default the N and C dimensions are set to 0. The dimension order is determined by the value of data_format, see below for details.
padding: A string, either 'VALID' or 'SAME'. The padding algorithm. See the "returns" section of tf.nn.convolution for details.
data_format: A string. 'NHWC' and 'NCHW' are supported.
name: Optional name for the returned tensor.
input: Alias for value.
filters: Alias for filter.
dilations: An int or list of ints that has length 1, 2 or 4, defaults to 1. The dilation factor for each dimension of input. If a single value is given it is replicated in the H and W dimension. By default the N and C dimensions are set to 1. If set to k > 1, there will be k-1 skipped cells between each filter element on that dimension. The dimension order is determined by the value of data_format, see above for details. Dilations in the batch and depth dimensions if a 4-d tensor must be 1.
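To make the documented signature concrete, here is a minimal eager-mode sketch (assuming TensorFlow 2.x; the shapes are illustrative, not from the original post):

```python
import tensorflow as tf

# NHWC input: batch 1, 8x8 spatial, 64 channels
x = tf.ones([1, 8, 8, 64])
# filter layout is [height, width, output_channels, in_channels]
w = tf.ones([3, 3, 128, 64])

# stride 2 with SAME padding doubles the spatial dimensions
y = tf.nn.conv2d_transpose(x, w, output_shape=[1, 16, 16, 128],
                           strides=[1, 2, 2, 1], padding='SAME')
print(y.shape)
```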

Now for the pitfalls themselves. The first is determining the feature map size after upsampling. It turns out that when defining the graph you can pass output_shape as a tensor: use tf.shape() on whichever feature map you want the output to match, and the conv2d_transpose output will then have exactly that size. This keeps operations such as dense (skip) connections from breaking on shape mismatches.
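As a sketch of that trick (TensorFlow 2.x eager mode assumed; `x`, `ref`, and the shapes are made-up stand-ins for your own features):

```python
import tensorflow as tf

x = tf.ones([2, 8, 8, 64])        # feature to upsample
ref = tf.ones([2, 16, 16, 128])   # feature whose size the output should match

w = tf.ones([3, 3, 128, 64])      # [h, w, output_channels, in_channels]
# Passing tf.shape(ref) as output_shape makes the output size track ref,
# even when ref's shape is only known at run time.
y = tf.nn.conv2d_transpose(x, w, output_shape=tf.shape(ref),
                           strides=[1, 2, 2, 1], padding='SAME')
print(y.shape == ref.shape)
```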
The second pitfall: in conv2d_transpose the filter shape is [kernel_size, kernel_size, output_channels, input_channels], whereas in conv2d it is [kernel_size, kernel_size, input_channels, output_channels]. It is all too easy to write the conv2d ordering in a conv2d_transpose by mistake, and that is guaranteed to fail unless your input and output channel counts happen to be equal.
Here is my own mistaken code, for the record:

with tf.compat.v1.variable_scope('upsample1'):
    # BUG: [3, 3, 64, 128] is the conv2d ordering [h, w, in_channels, out_channels];
    # conv2d_transpose expects [h, w, out_channels, in_channels]
    weights = tf.compat.v1.get_variable(
        "w", [3, 3, 64, 128],
        initializer=tf.compat.v1.truncated_normal_initializer(stddev=1e-3))
    x1_upsample = tf.nn.conv2d_transpose(value=x1_merge, filter=weights,
                                         output_shape=tf.shape(feature_vi),
                                         strides=[1, 2, 2, 1], padding='SAME')
    x1_upsample = lrelu(x1_upsample)

My input feature has 64 channels and the expected output has 128 channels, but written as above it fails with:

ValueError: Incompatible shapes between op input and calculated input gradient.  Forward operation: fusion/FPAF_Model/Upsample/upsample1/conv2d_transpose.  Input index: 2. Original input shape: (32, 8, 8, 64).  Calculated input gradient shape: (32, 8, 8, 128)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimension 3 in both shapes must be equal, but are 128 and 64. Shapes are [32,8,8,128] and [32,8,8,64].

A wall of grim errors, enough to drive you to despair. A small change makes it run correctly:

with tf.compat.v1.variable_scope('upsample1'):
    # correct conv2d_transpose ordering: [h, w, out_channels, in_channels]
    weights = tf.compat.v1.get_variable(
        "w", [3, 3, 128, 64],
        initializer=tf.compat.v1.truncated_normal_initializer(stddev=1e-3))
    x1_upsample = tf.nn.conv2d_transpose(value=x1_merge, filter=weights,
                                         output_shape=tf.shape(feature_vi),
                                         strides=[1, 2, 2, 1], padding='SAME')
    x1_upsample = lrelu(x1_upsample)
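To see the two orderings side by side, here is a small sketch (TensorFlow 2.x assumed; shapes are illustrative) that runs the same 64-to-128-channel mapping through both ops:

```python
import tensorflow as tf

x = tf.ones([1, 8, 8, 64])

# conv2d filter: [h, w, in_channels, out_channels]
w_conv = tf.ones([3, 3, 64, 128])
y_conv = tf.nn.conv2d(x, w_conv, strides=[1, 1, 1, 1], padding='SAME')

# conv2d_transpose filter: [h, w, out_channels, in_channels] -- note the swap
w_deconv = tf.ones([3, 3, 128, 64])
y_deconv = tf.nn.conv2d_transpose(x, w_deconv, output_shape=[1, 16, 16, 128],
                                  strides=[1, 2, 2, 1], padding='SAME')

print(y_conv.shape, y_deconv.shape)
```

Both calls map 64 input channels to 128 output channels, but the channel dimensions sit in opposite slots of the filter shape.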