0 前言
根据文档中对于api介绍, tf.depth_to_space
与torch中pixelshuffle
功能应该是一致的,
都是把深度维的数据给move到宽度和高度上。如对于一个输入(1,12, 4,4), 上采样倍数为2时,
输出为(1,3,8,8),可以看到, 输出通道降低了4倍, 宽和高分别扩大了2倍。
然而实际测试发现了不一致的情况。
1 现象
当输出通道不为1的时候, 两者的结果不一致, 当输出通道为1的时候, 两者的结果一致。
试验代码如下:
import tensorflow.compat.v1 as tf
import torch
from torch.nn import functional as F
import numpy as np
def _tf_pixshuffle(input, up_scale):
return tf.nn.depth_to_space(tf.constant(input), up_scale, data_format='NCHW').numpy()
def _torch_pixshuffle(input, up_scale):
return F.pixel_shuffle(torch.Tensor(input), up_scale).numpy()
# 1 先来一个维度小点的, 可以直接比较数值
test_data1 = [[[[1, 2, 3, 15], [4, 5, 6, 16]],
[[7, 8, 9, 17], [10, 11, 12, 13]]]]
test_data1 = np.array(test_data1)
test_data1 = test_data1.transpose(0, 3, 1, 2) # NCHW
out_tf1 = _tf_pixshuffle(test_data1, 2)
out_torch1 = _torch_pixshuffle(test_data1, 2)
print(f"input shape:{test_data1.shape}, equal: {(out_tf1==out_torch1).all()}")
# 2 来个维度大点的,W和H相等, 注意用整数, 因为小数会有精度误差, 不好比较
test_data2 = np.random.randint(0, 100, size=(1, 4, 12, 12))
out_tf2 = _tf_pixshuffle(test_data2, 2)
out_torch2 = _torch_pixshuffle(test_data2, 2)
print(f"input shape:{test_data2.shape}, equal: {(out_tf2==out_torch2).all()}")
# 3 来个维度大点的,W和H不相等, 注意用整数, 因为小数会有精度误差, 不好比较
test_data3 = np.random.randint(0, 100, size=(1, 4, 12, 16))
out_tf3 = _tf_pixshuffle(test_data3, 2)
out_torch3 = _torch_pixshuffle(test_data3, 2)
print(f"input shape:{test_data3.shape}, equal: {(out_tf3==out_torch3).all()}")
# 4 来个维度大点的,W和H不相等,通道大于1 注意用整数, 因为小数会有精度误差, 不好比较
test_data4 = np.random.randint(0, 100, size=(1, 48, 12, 16))
out_tf4 = _tf_pixshuffle(test_data4, 2)
out_torch4 = _torch_pixshuffle(test_data4, 2)
print(f"input shape:{test_data4.shape}, equal: {(out_tf4==out_torch4).all()}")
运行结果如下:
input shape:(1, 4, 2, 2), equal: True
input shape:(1, 4, 12, 12), equal: True
input shape:(1, 4, 12, 16), equal: True
input shape:(1, 48, 12, 16), equal: False
2 解决方法
自定义pixshuffle功能, 下述定义与torch的pixelshuffle
功能等价, 但是速度比tf原本的depth_to_space
慢很多。
def _pixel_shuffle(input, upscale_factor):
batch_size, channels, in_height, in_width = input.shape.as_list()
new_channels = channels //(upscale_factor*upscale_factor)
out_height = in_height * upscale_factor
out_width = in_width * upscale_factor
input_reshape = tf.reshape(input, [batch_size, new_channels, upscale_factor, upscale_factor,in_height, in_width])
shuffle_out = tf.transpose(input_reshape,perm =[0, 1, 4, 2, 5, 3] )
out = tf.reshape(shuffle_out, [batch_size, new_channels, out_height, out_width] )
return out
为了提升速度, 还是得采用depth_to_space
接口, 只是需要进行一定的改造。
通过上面的观察我们发现, 既然输出通道为1时, 输出是正确的, 那么是否可以拆分原本输入的通道,每次只计算一个输出通道为1的数据块, 然后把计算的结果给concat起来呢?
def _tf_pixelshuffle_cat(input, upscale_factor):
input = tf.constant(input)
temp = []
depth = upscale_factor *upscale_factor
channels = input.shape.as_list()[1] // depth
for i in range(channels):
out_ = tf.nn.depth_to_space(input=input[:,i*depth:(i+1)*depth, :,:], block_size=upscale_factor, data_format="NCHW")
temp.append(out_)
out = tf.concat(temp, axis=1)
out = out.numpy()
return out
通过实验发现, 这个结果跟torch的pixelshuffle
的输出是一致的, 速度上也跟tf原本的depth_to_space
差不多。 至此, 问题完美解决。
3 参考
网上也有其他用户发现了类似现象的
[1] Is Depth_to_space function totally different to Pixelshuffle? https://discuss.tensorflow.org/t/is-depth-to-space-function-totally-different-to-pixelshuffle/3086
[2] TF depth_to_space not same as Torch’s PixelShuffle when output channels > 1? https://stackoverflow.com/questions/68272502/tf-depth-to-space-not-same-as-torchs-pixelshuffle-when-output-channels-1