pytorch逆亚像素

#-*-encoding:utf-8-*-
"""
# function/功能 : 
# @File : 测试亚像素.py 
# @Time : 2021/1/26 9:33 
# @Author : kf
# @Software: PyCharm
"""
import torch
seed=10
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True



# 逆亚像素卷积1
def de_subpix(y):
    # print('索引为偶数的项:', demo_list[::2])
    # print('索引为奇数的项:', demo_list[1::2])
    d1=y[:, :, ::2,::2]
    d2=y[:, :, 1::2,::2]
    d3=y[:, :, ::2,1::2]
    d4=y[:, :, 1::2,1::2]
    out = torch.cat([d1, d2, d3, d4], 1)
    return out

# 逆亚像素卷积2
def de_subpix2(y):
    (b, c, h, w) = y.shape
    h1 = int(h // 2)
    w1 = int(w // 2)
    d1 = torch.zeros((b, c, h1, w1))
    d2 = torch.zeros((b, c, h1, w1))
    d3 = torch.zeros((b, c, h1, w1))
    d4 = torch.zeros((b, c, h1, w1))
    for i in range(0, h1, 1):
        for j in range(0, w1, 1):
            d1[:, :, i, j] = y[:, :, 2 * i, 2 * j]
            d2[:, :, i, j] = y[:, :, 2 * i + 1, 2 * j]
            d3[:, :, i, j] = y[:, :, 2 * i, 2 * j + 1]
            d4[:, :, i, j] = y[:, :, 2 * i + 1, 2 * j + 1]
    out = torch.cat([d1, d2, d3, d4], 1)
    # print(out.shape)
    return out

# 逆亚像素卷积3
def de_pixelshuffle(input, downscale_factor):     # channal
    batch_size, channels, in_height, in_width = input.size()
    out_height = in_height // downscale_factor
    out_width = in_width // downscale_factor
    input_view = input.contiguous().view(batch_size, channels, out_height, downscale_factor, out_width, downscale_factor)
    channels = channels *downscale_factor ** 2
    shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()
    shuffle_out = shuffle_out.view(batch_size, channels, out_height, out_width)
    return shuffle_out



test_x = (torch.rand(1, 16, 16, 16))
print('test_x: {}'.format(test_x.shape))

out1=de_pixelshuffle(test_x,2)
print('output: {}'.format(out1.shape))


out=de_subpix(test_x)

out2=de_subpix2(test_x)

print(out)
print('output: {}'.format(out.shape))

# 亚像素卷积
ps = torch.nn.PixelShuffle(2)
outup=ps(test_x)
print('outup: {}'.format(outup.shape))

在这里插入图片描述
在这里插入图片描述

在使用3个不同逆亚像素过程中,发现

de_subpix和de_subpix2结果相同,de_pixelshuffle结果不同,这是因为de_pixelshuffle被压缩成一维,不能够提取准确位置,因此不能够使用进行逆亚像素。

不使用for循环,是因为索引更快。

时间对比,单位是s:

de_subpix:0.0
de_subpix2:0.00401616096496582

通过::x进行索引,得到x的整数倍索引

demo_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
x=demo_list[::2]
demo_list[1::2]
Out[4]: [1, 3, 5, 7, 9]
demo_list[::3]
Out[5]: [0, 3, 6, 9]
demo_list[1::3]
Out[6]: [1, 4, 7]
demo_list[2::3]
Out[7]: [2, 5, 8]

最终采用方法:de_subpix

参考1:https://blog.csdn.net/aaa958099161/article/details/90230541?utm_medium=distribute.pc_relevant.none-task-blog-searchFromBaidu-2.control&depth_1-utm_source=distribute.pc_relevant.none-task-blog-searchFromBaidu-2.control

参考2:https://blog.csdn.net/qq_38818384/article/details/106904989

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值