MVSalNet代码学习记录

class transDDPM(nn.Module):     # 解码器
    def __init__(self, in_xC, out_C):       # in_xC为144,out_C为64
        super().__init__()
        self.down_input = nn.Conv2d(out_C, out_C//4, 1)
        self.unfold1 = nn.Unfold(kernel_size=3, dilation=1, padding=1, stride=1)    # 将图片切割成patch
        self.unfold3 = nn.Unfold(kernel_size=3, dilation=3, padding=3, stride=1)    # 以kernel_size作为patch
        self.unfold5 = nn.Unfold(kernel_size=3, dilation=5, padding=5, stride=1)    # 输出大小为[n, c*k*k, h*w]
        self.fuse = BasicConv2d(out_C, out_C, 3, 1, 1)

    def forward(self, x, y):    # x为F^i_e,y为经transformer后的x
        x = self.down_input(x)
        N, xC, xH, xW = x.size()
        unfold_x1 = self.unfold1(x).reshape([N, xC, -1, xH, xW])        # unfold输出大小为[n, c*k*k, h*w]
        unfold_x2 = self.unfold3(x).reshape([N, xC, -1, xH, xW])        # reshape输出大小为[n, c, k*k, h, w]
        unfold_x3 = self.unfold5(x).reshape([N, xC, -1, xH, xW])

        y0 = y[0].reshape([N, xC, 9, xH, xW])       # y大小为[3,N,144,xH,xW]
        y1 = y[1].reshape([N, xC, 9, xH, xW])       # y[i]大小为[N,xC,9,xH,xW]
        y2 = y[2].reshape([N, xC, 9, xH, xW])
        result1 = (unfold_x1 * y0).sum(2)       # sum(2)对第三维求和
        result2 = (unfold_x2 * y1).sum(2)
        result3 = (unfold_x3 * y2).sum(2)
        return self.fuse(torch.cat((x, result1, result2, result3), dim=1))

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值