class transDDPM(nn.Module):  # decoder
    """Per-pixel dynamic filtering at three dilation rates, then fusion.

    The input feature map is projected to ``out_C // 4`` channels. For each
    dilation rate (1, 3, 5), every 3x3 neighbourhood of every pixel is
    multiplied element-wise by a per-pixel, per-channel kernel taken from
    ``y`` and summed. The projected input and the three filtered maps are
    concatenated and passed through ``self.fuse`` to produce ``out_C``
    channels.

    Args:
        in_xC: Not used to build any layer. Per the original author's note
            it is 144, which equals ``(out_C // 4) * 9`` for ``out_C == 64``,
            i.e. the channel count of each ``y[i]`` -- NOTE(review): confirm
            with callers.
        out_C: Channel count of the input ``x`` and of the output
            (64 in the original configuration).
    """

    def __init__(self, in_xC, out_C):
        super().__init__()
        # Width of each branch; the concatenation in forward() therefore has
        # 4 * mid_c channels, which equals out_C only if out_C % 4 == 0.
        self.mid_c = out_C // 4
        self.down_input = nn.Conv2d(out_C, self.mid_c, 1)
        # nn.Unfold extracts 3x3 patches around every pixel; padding == dilation
        # keeps the number of patch positions equal to H * W.
        self.unfold1 = nn.Unfold(kernel_size=3, dilation=1, padding=1, stride=1)
        self.unfold3 = nn.Unfold(kernel_size=3, dilation=3, padding=3, stride=1)
        self.unfold5 = nn.Unfold(kernel_size=3, dilation=5, padding=5, stride=1)
        # Input width is derived from mid_c (not out_C) so that out_C values
        # not divisible by 4 no longer cause a channel mismatch.
        self.fuse = BasicConv2d(4 * self.mid_c, out_C, 3, 1, 1)

    def forward(self, x, y):
        """Filter ``x`` with the kernels in ``y`` and fuse the results.

        Args:
            x: Tensor of shape ``[N, out_C, H, W]`` (the encoder feature
                F^i_e according to the original author's note).
            y: Indexable with at least three elements (the original note
                describes it as ``[3, N, 144, H, W]``, the transformer
                output). Each ``y[i]`` must be reshapeable to
                ``[N, out_C // 4, 9, H, W]``; ``y[0]``, ``y[1]`` and
                ``y[2]`` are the kernels for dilation 1, 3 and 5.

        Returns:
            ``self.fuse`` applied to the channel-wise concatenation of the
            projected ``x`` and the three filtered maps.
        """
        x = self.down_input(x)
        n, c, h, w = x.size()
        unfolds = (self.unfold1, self.unfold3, self.unfold5)
        filtered = []
        for i, unfold in enumerate(unfolds):
            # Unfold output is [N, C*9, H*W] with channels outermost, so it
            # can be viewed as [N, C, 9, H, W].
            patches = unfold(x).reshape([n, c, -1, h, w])
            kernels = y[i].reshape([n, c, 9, h, w])
            # Weighted sum over the 9 neighbourhood positions -> [N, C, H, W].
            filtered.append((patches * kernels).sum(2))
        return self.fuse(torch.cat((x, *filtered), dim=1))
MVSalNet代码学习记录
于 2023-10-07 20:18:12 首次发布