tf depth_to_space 与torch pixshuffle 踩坑

0 前言

根据文档中对于api介绍, tf.depth_to_space与torch中pixelshuffle功能应该是一致的,
都是把深度维的数据给move到宽度和高度上。如对于一个输入(1,12, 4,4), 上采样倍数为2时,
输出为(1,3,8,8),可以看到, 输出通道降低了4倍, 宽和高分别扩大了2倍。
然而实际测试发现了不一致的情况。

1 现象

当输出通道不为1的时候, 两者的结果不一致, 当输出通道为1的时候, 两者的结果一致。

试验代码如下:

import tensorflow.compat.v1 as tf
import torch
from torch.nn import functional as F
import numpy as np


def _tf_pixshuffle(input, up_scale):
    return tf.nn.depth_to_space(tf.constant(input), up_scale, data_format='NCHW').numpy()


def _torch_pixshuffle(input, up_scale):
    return F.pixel_shuffle(torch.Tensor(input), up_scale).numpy()


# 1 先来一个维度小点的, 可以直接比较数值
test_data1 = [[[[1, 2, 3, 15], [4, 5, 6, 16]],
               [[7, 8, 9, 17], [10, 11, 12, 13]]]]

test_data1 = np.array(test_data1)
test_data1 = test_data1.transpose(0, 3, 1, 2)  # NCHW

out_tf1 = _tf_pixshuffle(test_data1, 2)
out_torch1 = _torch_pixshuffle(test_data1, 2)

print(f"input shape:{test_data1.shape}, equal: {(out_tf1==out_torch1).all()}")


# 2 来个维度大点的,W和H相等, 注意用整数, 因为小数会有精度误差, 不好比较
test_data2 = np.random.randint(0, 100, size=(1, 4, 12, 12))
out_tf2 = _tf_pixshuffle(test_data2, 2)
out_torch2 = _torch_pixshuffle(test_data2, 2)
print(f"input shape:{test_data2.shape}, equal: {(out_tf2==out_torch2).all()}")

# 3 来个维度大点的,W和H不相等, 注意用整数, 因为小数会有精度误差, 不好比较
test_data3 = np.random.randint(0, 100, size=(1, 4, 12, 16))
out_tf3 = _tf_pixshuffle(test_data3, 2)
out_torch3 = _torch_pixshuffle(test_data3, 2)
print(f"input shape:{test_data3.shape}, equal: {(out_tf3==out_torch3).all()}")


# 4 来个维度大点的,W和H不相等,通道大于1 注意用整数, 因为小数会有精度误差, 不好比较
test_data4 = np.random.randint(0, 100, size=(1, 48, 12, 16))
out_tf4 = _tf_pixshuffle(test_data4, 2)
out_torch4 = _torch_pixshuffle(test_data4, 2)
print(f"input shape:{test_data4.shape}, equal: {(out_tf4==out_torch4).all()}")

运行结果如下:

input shape:(1, 4, 2, 2), equal: True
input shape:(1, 4, 12, 12), equal: True
input shape:(1, 4, 12, 16), equal: True
input shape:(1, 48, 12, 16), equal: False

2 解决方法

自定义pixshuffle功能, 下述定义与torch的pixelshuffle功能等价, 但是速度比tf原本的depth_to_space慢很多。

def _pixel_shuffle(input, upscale_factor):
    batch_size, channels, in_height, in_width =  input.shape.as_list()
    new_channels = channels //(upscale_factor*upscale_factor)
    out_height = in_height * upscale_factor
    out_width = in_width * upscale_factor

    input_reshape = tf.reshape(input, [batch_size, new_channels, upscale_factor, upscale_factor,in_height, in_width])
    shuffle_out = tf.transpose(input_reshape,perm =[0, 1, 4, 2, 5, 3] )
    out = tf.reshape(shuffle_out, [batch_size, new_channels, out_height, out_width] )
    return out

为了提升速度, 还是得采用depth_to_space接口, 只是需要进行一定的改造。
通过上面的观察我们发现, 既然输出通道为1时, 输出是正确的, 那么是否可以拆分原本输入的通道,每次只计算一个输出通道为1的数据块, 然后把计算的结果给concat起来呢?

def _tf_pixelshuffle_cat(input, upscale_factor):
    input = tf.constant(input)
    temp = []
    depth = upscale_factor *upscale_factor
    channels = input.shape.as_list()[1] // depth
    for i in range(channels):
        out_ = tf.nn.depth_to_space(input=input[:,i*depth:(i+1)*depth, :,:], block_size=upscale_factor, data_format="NCHW")
        temp.append(out_)
    out = tf.concat(temp, axis=1)
    out = out.numpy()
    return out

通过实验发现, 这个结果跟torch的pixelshuffle的输出是一致的, 速度上也跟tf原本的depth_to_space差不多。 至此, 问题完美解决。

3 参考

网上也有其他用户发现了类似现象的
[1] Is Depth_to_space function totally different to Pixelshuffle? https://discuss.tensorflow.org/t/is-depth-to-space-function-totally-different-to-pixelshuffle/3086

[2] TF depth_to_space not same as Torch’s PixelShuffle when output channels > 1? https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值