import torch
import torch.nn as nn

# Stage, FusionMamba and SpeAttention are used below but not shown in this
# snippet; they are assumed to be defined elsewhere in the project.


class PixelShuffle(nn.Module):
    """Sub-pixel upsampling: expand channels with a conv, then rearrange them spatially."""

    def __init__(self, dim, scale):
        super().__init__()
        # The conv expands the channel count to dim * scale^2 so that
        # nn.PixelShuffle can fold the extra channels into a scale-times
        # larger spatial grid while keeping `dim` output channels.
        self.upsample = nn.Sequential(
            nn.Conv2d(dim, dim * (scale ** 2), 3, 1, 1, bias=False),
            nn.PixelShuffle(scale)
        )

    def forward(self, x):
        return self.upsample(x)


class U2Net(nn.Module):
    def __init__(self, dim, pan_dim, ms_dim, H=256, W=256, scale=4):
        super().__init__()
        # Upsample the low-resolution MS image by `scale` via sub-pixel convolution.
        self.upsample = PixelShuffle(ms_dim, scale)
        # Project the PAN image and the upsampled MS image into the shared working dimension `dim`.
        self.raise_pan_dim = nn.Sequential(
            nn.Conv2d(pan_dim, dim, 3, 1, 1),
            nn.LeakyReLU()
        )
        self.raise_ms_dim = nn.Sequential(
            nn.Conv2d(ms_dim, dim, 3, 1, 1),
            nn.LeakyReLU()
        )
        # Map the fused features back to ms_dim bands.
        self.to_hrms = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.LeakyReLU(),
            nn.Conv2d(dim, ms_dim, 3, 1, 1)
        )

        # dimension for each stage
        dim0 = dim
        dim1 = int(dim0 * 2)
        dim2 = int(dim1 * 2)

        # main body: two downsampling stages, two upsampling stages with
        # skip connections, and a final fusion block.
        self.stage0 = Stage(dim0, dim1, H, W, sample_mode='down')
        self.stage1 = Stage(dim1, dim2, H // 2, W // 2, sample_mode='down')
        self.stage2 = Stage(dim2, dim1, H // 4, W // 4, sample_mode='up')
        self.stage3 = Stage(dim1, dim0, H // 2, W // 2, sample_mode='up')
        self.stage4 = FusionMamba(dim0, H, W, final=True)

        # Spectral attention computed from the original low-resolution MS.
        self.spe_attn = SpeAttention(ms_dim, 16, 'mamba', dim)

    def forward(self, input):
        ms = input[0]
        pan = input[1]
        lrms = ms

        ms = self.upsample(ms)  # sub-pixel upsampling; channel count stays at ms_dim
        skip = ms               # residual connection to the upsampled MS

        pan = self.raise_pan_dim(pan)
        ms = self.raise_ms_dim(ms)

        # main body
        pan, ms, pan_skip0, ms_skip0 = self.stage0(pan, ms)
        pan, ms, pan_skip1, ms_skip1 = self.stage1(pan, ms)
        pan, ms = self.stage2(pan, ms, pan_skip1, ms_skip1)
        pan, ms = self.stage3(pan, ms, pan_skip0, ms_skip0)
        output = self.stage4(pan, ms)

        spe_attn = self.spe_attn(lrms)
        # Reweight the reconstructed bands with the spectral attention,
        # then add the upsampled MS as a residual.
        output = self.to_hrms(output) * spe_attn + skip
        return output
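To make the sub-pixel upsampling step concrete, here is a minimal, runnable sketch that exercises only the PixelShuffle wrapper above on a dummy low-resolution MS tensor. The batch size, band count (4), spatial size (64x64) and scale (4) are illustrative assumptions, not values taken from the original post.

if __name__ == "__main__":
    # Assumed example values: 4 spectral bands, 64x64 LRMS, 4x upsampling.
    ms_dim, scale = 4, 4
    up = PixelShuffle(ms_dim, scale)

    lrms = torch.randn(1, ms_dim, 64, 64)  # dummy low-resolution MS input
    hrms_feat = up(lrms)                   # conv to ms_dim * scale^2 channels, then nn.PixelShuffle

    # Channels stay at ms_dim; height and width grow by `scale`.
    print(lrms.shape)       # torch.Size([1, 4, 64, 64])
    print(hrms_feat.shape)  # torch.Size([1, 4, 256, 256])

The net effect is that the number of bands is preserved while the spatial resolution grows by `scale`, which is why `skip = ms` and the final `+ skip` residual in U2Net.forward match the shape of the `to_hrms` output.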