FusionMamba

class PixelShuffle(nn.Module):
    def __init__(self, dim, scale):
        super().__init__()
        # 调整卷积层,确保输入的通道数能整除 scale^2
        self.upsample = nn.Sequential(
            nn.Conv2d(dim, dim * (scale ** 2), 3, 1, 1, bias=False),  # 将通道数变为 dim * scale^2
            nn.PixelShuffle(scale)
        )

    def forward(self, x):
        return self.upsample(x)


class U2Net(nn.Module):
    def __init__(self, dim, pan_dim, ms_dim, H=256, W=256, scale=4):
        super().__init__()

        self.upsample = PixelShuffle(ms_dim, scale)  # 这里需要确保ms_dim的通道数合适
        self.raise_pan_dim = nn.Sequential(
            nn.Conv2d(pan_dim, dim, 3, 1, 1),
            nn.LeakyReLU()
        )
        self.raise_ms_dim = nn.Sequential(
            nn.Conv2d(ms_dim, dim, 3, 1, 1),
            nn.LeakyReLU()
        )
        self.to_hrms = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.LeakyReLU(),
            nn.Conv2d(dim, ms_dim, 3, 1, 1)
        )

        # dimension for each stage
        dim0 = dim
        dim1 = int(dim0 * 2)
        dim2 = int(dim1 * 2)

        # main body
        self.stage0 = Stage(dim0, dim1, H, W, sample_mode='down')
        self.stage1 = Stage(dim1, dim2, H//2, W//2, sample_mode='down')
        self.stage2 = Stage(dim2, dim1, H//4, W//4, sample_mode='up')
        self.stage3 = Stage(dim1, dim0, H//2, W//2, sample_mode='up')
        self.stage4 = FusionMamba(dim0, H, W, final=True)

        self.spe_attn = SpeAttention(ms_dim, 16, 'mamba', dim)

    def forward(self, input):
        ms = input[0]
        pan = input[1]
        lrms = ms
        ms = self.upsample(ms)  # 这里会通过PixelShuffle调整通道数
        skip = ms
        pan = self.raise_pan_dim(pan)
        ms = self.raise_ms_dim(ms)

        # main body
        pan, ms, pan_skip0, ms_skip0 = self.stage0(pan, ms)
        pan, ms, pan_skip1, ms_skip1 = self.stage1(pan, ms)
        pan, ms = self.stage2(pan, ms, pan_skip1, ms_skip1)
        pan, ms = self.stage3(pan, ms, pan_skip0, ms_skip0)
        output = self.stage4(pan, ms)

        spe_attn = self.spe_attn(lrms)

        output = self.to_hrms(output) * spe_attn + skip

        return output

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值