【torch.nn.PixelShuffle】和 【torch.nn.UnpixelShuffle】

torch.nn.PixelShuffle

直观解释

PixelShuffle是一种上采样方法,它将形状为 ( ∗ , C × r 2 , H , W ) (∗, C\times r^2, H, W) (,C×r2,H,W)的张量重新排列转换为形状为 ( ∗ , C , H × r , W × r ) (∗, C, H\times r, W\times r) (,C,H×r,W×r)的张量:
在这里插入图片描述


举个例子
输入的张量大小是(1,8,2,3),PixelShuffle的 缩放因子是r=2

import torch
ps=torch.nn.PixelShuffle(2)
input=torch.arange(0,48).view(1,8,2,3)
print(input)
output=ps(input)
print(output)
print(output.shape)

如下图可以看到,PixelShuffle是把输入通道按照缩放因子r^2进行划分成8/(2^2)=2 组。
也就是输入的第一组(前4个通道)中的元素,每次间隔r=2 交错排列,合并成输出的第一个通道维度。
输入的第二组(后4个通道)中的元素,每次间隔r=2交错排列,合并成输出通道的第二个维度。
输入的大小为(batchsize,in_channel,in_height,in_width)=(1,8,2,3)
输出的大小为(batchsize,out_channel,out_height,out_width)(1,2,4,6)

各个维度的变化规律如下:
batchsize 不变;
out_channel=in_channel/(r^2)
out_height=in_height*r
out_width=in_width*r
在这里插入图片描述

官方文档

CLASS
torch.nn.PixelShuffle(upscale_factor)
  • 功能: 把大小为 ( ∗ , C × r 2 , H , W ) (*,C\times r^2,H,W) (,C×r2,H,W)的张量重新排列为大小为 ( ∗ , C , H × r , W × r ) (*,C,H\times r,W\times r) (,C,H×r,W×r) , 其中 r r r 是 upscale factor 。

    这个操作对于实现步长为 1 r \frac {1}{r} r1efficient sub-pixel convolution有用。

  • 参数

    • upscale_factor(int) : 增加空间分辨率的因子
  • 形状

    • 输入: ( ∗ , C i n , H i n , W i n ) (*,C_{in},H_{in},W_{in}) (,Cin,Hin,Win) ,其中 ∗ * 是 0 或者batch大小

    • 输出: ( ∗ , C o u t , H o u t , W o u t ) (*,C_{out},H_{out},W_{out}) (,Cout,Hout,Wout) , 其中

      C out  = C in  ÷ u p s c a l e _ f a c t o r 2 H out  = H in  × u p s c a l e _ f a c t o r W out  = W in  × u p s c a l e _ f a c t o r C_{\text {out }}=C_{\text {in }} \div upscale\_factor ^2 \\ H_{\text {out }}=H_{\text {in }} \times upscale\_factor \\ W_{\text {out }}=W_{\text {in }} \times upscale\_factor Cout =Cin ÷upscale_factor2Hout =Hin ×upscale_factorWout =Win ×upscale_factor

  • 例子

>>> pixel_shuffle = nn.PixelShuffle(3)
>>> input = torch.randn(1, 9, 4, 4)
>>> output = pixel_shuffle(input)
>>> print(output.size())
torch.Size([1, 1, 12, 12])

torch.nn.PixelUnshuffle

直观解释

PixelUnshuffle就是PixelShuffle的逆操作。

import torch
pus=torch.nn.PixelUnshuffle(2)
input_restore=pus(putput)
print(input_restore)
print(input_restore.shape)
print(input_restore==input) # input_restore和input一样

官方文档

CLASS
torch.nn.PixelUnshuffle(downscale_factor)
  • 功能: 是PixelShuffle的逆操作,把大小为 ( ∗ , C , H × r , W × r ) (*,C,H\times r,W\times r) (,C,H×r,W×r)的张量重组成大小为 ( ∗ , C × r , H , W ) (*,C\times r,H,W) (,C×r,H,W)的张量。其中 r r r 是downscale factor。

  • 参数:

    • downscale_factor (int) : 降低空间分辨率的因子。
  • 形状:

    • 输入: ( ∗ , C i n , H i n , W i n ) (*,C_{in},H_{in},W_{in}) (,Cin,Hin,Win), 其中 ∗ * 是 0 或者batch大小

    • 输出: ( ∗ , C o u t , H o u t , W o u t ) (*,C_{out},H_{out},W_{out}) (,Cout,Hout,Wout), 其中

      C out  = C in  ×  downscale  _ factor  2 H out  = H in  ÷  downscale  _ factor  W out  = W in  ÷  downscale  _ factor  \begin{aligned}& C_{\text {out }}=C_{\text {in }} \times \text { downscale } \_ \text {factor }{ }^2 \\& H_{\text {out }}=H_{\text {in }} \div \text { downscale } \_ \text {factor } \\& W_{\text {out }}=W_{\text {in }} \div \text { downscale } \_ \text {factor }\end{aligned} Cout =Cin × downscale _factor 2Hout =Hin ÷ downscale _factor Wout =Win ÷ downscale _factor 

  • 例子

>>> pixel_unshuffle = nn.PixelUnshuffle(3)
>>> input = torch.randn(1, 1, 12, 12)
>>> output = pixel_unshuffle(input)
>>> print(output.size())
torch.Size([1, 9, 4, 4])
  • 4
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

zyw2002

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值