[深度学习] SpaceToDepth 类

类代码

class SpaceToDepth(nn.Module):
    def __init__(self, block_size):
        super(SpaceToDepth, self).__init__()
        self.block_size = block_size
        self.block_size_sq = block_size*block_size

    def forward(self, input):
        output = input.permute(0, 2, 3, 1)
        (batch_size, s_height, s_width, s_depth) = output.size()
        d_depth = s_depth * self.block_size_sq
        d_width = int(s_width / self.block_size)
        d_height = int(s_height / self.block_size)
        t_1 = output.split(self.block_size, 2)
        stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]
        output = torch.stack(stack, 1)
        output = output.permute(0, 2, 1, 3)
        output = output.permute(0, 3, 1, 2)
        return output

逐句解析

labels = space2depth(labels)

将 labels 输入给 space2depth对象,调用对象里的 forward函数,input:(64,1,120,160)

output = input.permute(0, 2, 3, 1)

交换维度,output:(64,120,160,1)

(batch_size, s_height, s_width, s_depth) = output.size()

batch_size=64,s_height=120,s_width=160,s_depth=1

d_depth = s_depth * self.block_size_sq

其中, s e l f . b l o c k _ s i z e _ s q = b l o c k _ s i z e ∗ b l o c k _ s i z e = 8 ∗ 8 = 64 self.block\_size\_sq=block\_size*block\_size=8*8=64 self.block_size_sq=block_sizeblock_size=88=64
d_depth=64

d_width = int(s_width / self.block_size)
d_height = int(s_height / self.block_size)

d_width=20
d_height=15

t_1 = output.split(self.block_size, 2)

output:(64,120,160,1)

在 output 的第2维度上,按没 block_size=8,切块

len(t_1)=20

t_1[0].shape=(64, 120, 8, 1)

把batch size 看成1时,可以方便理解这步拆分,如下图:
在这里插入图片描述

stack = [t_t.reshape(batch_size, d_height, d_depth) for t_t in t_1]

对 t_1中的每一块(64,120,8,1),总共20块。

reshape成(64,15,64)

同样把batch size 看成1 用于方便理解这部reshape,如下图:
在这里插入图片描述
len(stack)=20

output = torch.stack(stack, 1)

把 stack中的 20块,在第1维度堆叠起来:
在这里插入图片描述
output:(64,20,15,64)

output = output.permute(0, 2, 1, 3)

output:(64,15,20,64)

output = output.permute(0, 3, 1, 2)

output:(64,64,15,20)

return 返回,回到 labels = space2depth(labels)

也就是说,(64,1,120,160)的labels 经过 space2depth的处理,转换成了 (64,64,15,20)的labels

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值