# duiqi (对齐, "alignment")

import torch
import torch.nn as nn
import torch.nn.functional as F
# from inplace_abn import InPlaceABN, InPlaceABNSync

class CAB_low(nn.Module):
    """Fuse a low-level feature map with a coarser high-level one.

    Two outputs are produced:

    * an aligned fusion: the projected high-level map and the low-level map
      are each resampled on the low-level grid with learned offsets
      (``delta_gen1`` / ``delta_gen2``) and summed;
    * a gated refinement of ``low_stage``: a one-channel sigmoid map computed
      from ``high_stage`` re-weights ``low_stage`` before a 1x1 projection.

    The layer widths imply ``low_stage`` has ``features`` channels and
    ``high_stage`` has ``4 * features`` channels.
    """

    def __init__(self, features):
        super(CAB_low, self).__init__()

        # Projects high_stage (4*features channels) down to `features`.
        self.conv1 = nn.Conv2d(features * 4, features, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(features)

        # Not used by forward(); kept so existing attribute access still works.
        self.maxpooling = nn.MaxPool2d(2)
        # One-channel gate computed from the (resized) 4*features high_stage.
        self.conv2 = nn.Conv2d(features * 4, 1, kernel_size=1)

        # Projects cat(gated low_stage, low_stage) (2*features) to `features`.
        self.conv3 = nn.Conv2d(features * 2, features, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(features)

        # Offset predictors; their 2 output channels are (x, y) offsets.
        self.delta_gen1 = self._make_delta_gen(features)
        self.delta_gen2 = self._make_delta_gen(features)

        # Zero the last conv so the predicted offsets start at exactly zero.
        nn.init.zeros_(self.delta_gen1[3].weight)
        nn.init.zeros_(self.delta_gen2[3].weight)

    @staticmethod
    def _make_delta_gen(features):
        """Build an offset head: 2*features channels in, 2 channels out."""
        return nn.Sequential(
            nn.Conv2d(features * 2, features, kernel_size=1, bias=False),
            # InPlaceABNSync(features) could replace the BatchNorm2d below.
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Conv2d(features, 2, kernel_size=3, padding=1, bias=False),
        )

    @staticmethod
    def _base_grid(input, size):
        """Return an (N, out_h, out_w, 2) identity grid spanning [-1, 1].

        The last dim is ordered (x, y), as ``F.grid_sample`` expects.
        """
        out_h, out_w = size
        n = input.shape[0]
        ys = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)
        xs = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)
        grid = torch.stack((xs, ys), dim=2).repeat(n, 1, 1, 1)
        return grid.to(device=input.device, dtype=input.dtype)

    @staticmethod
    def _scaled_offset(delta, norm):
        """Convert ``delta`` into grid units by dividing by ``norm``.

        A tensor ``delta`` is expected as (N, 2, out_h, out_w) and is moved to
        channels-last to line up with the grid. A plain number (the ``0``
        default) is broadcast as-is instead of crashing on ``.permute``.
        """
        if torch.is_tensor(delta):
            delta = delta.permute(0, 2, 3, 1)
        return delta / norm

    # https://github.com/speedinghzl/AlignSeg/issues/7
    # The normalization term is [w/s, h/s] rather than [h/s, w/s].
    # bilinear_interpolate_torch_gridsample2 is the standard implementation;
    # use it for training. forward() below still calls this variant.
    def bilinear_interpolate_torch_gridsample(self, input, size, delta=0):
        """Sample ``input`` on an identity grid of ``size`` shifted by ``delta``.

        Args:
            input: 4-D (N, C, H, W) tensor to sample from.
            size: (out_h, out_w) of the sampling grid.
            delta: offsets, expected (N, 2, out_h, out_w), or a number.

        Returns:
            (N, C, out_h, out_w) tensor.
        """
        n, c, h, w = input.shape
        s = 1.0
        norm = torch.tensor([[[[w / s, h / s]]]], dtype=input.dtype, device=input.device)
        grid = self._base_grid(input, size) + self._scaled_offset(delta, norm)
        # align_corners=False is PyTorch's default since 1.3; it is passed
        # explicitly to keep the same results without the default-change warning.
        return F.grid_sample(input, grid, align_corners=False)

    def bilinear_interpolate_torch_gridsample2(self, input, size, delta=0):
        """Standard variant: offsets in output-pixel units, align_corners=True.

        Args:
            input: 4-D (N, C, H, W) tensor to sample from.
            size: (out_h, out_w) of the sampling grid.
            delta: offsets, expected (N, 2, out_h, out_w), or a number.

        Returns:
            (N, C, out_h, out_w) tensor.
        """
        out_h, out_w = size
        s = 2.0
        # Ordered [x, y] (not [h, w]) to match the grid layout. max(.., 1)
        # avoids 0/0 = NaN when an output dimension is a single pixel.
        norm = torch.tensor(
            [[[[max(out_w - 1, 1) / s, max(out_h - 1, 1) / s]]]],
            dtype=input.dtype,
            device=input.device,
        )
        grid = self._base_grid(input, size) + self._scaled_offset(delta, norm)
        return F.grid_sample(input, grid, align_corners=True)

    def forward(self, low_stage, high_stage):
        """Fuse ``low_stage`` with ``high_stage``.

        Args:
            low_stage: (N, features, h, w) feature map.
            high_stage: (N, 4*features, H, W) feature map; resized to (h, w)
                when its spatial size differs.

        Returns:
            Tuple ``(aligned, gated)``, each (N, features, h, w).
        """
        h, w = low_stage.shape[2:]

        # Aligned fusion: project + upsample high_stage, predict offsets from
        # both maps, warp each map, and sum.
        high_stage1 = F.relu(self.bn1(self.conv1(high_stage)))
        high_stage1 = F.interpolate(input=high_stage1, size=(h, w), mode='bilinear', align_corners=True)
        concat = torch.cat((low_stage, high_stage1), 1)
        delta1 = self.delta_gen1(concat)
        delta2 = self.delta_gen2(concat)
        high_stage1 = self.bilinear_interpolate_torch_gridsample(high_stage1, (h, w), delta1)
        low_stage1 = self.bilinear_interpolate_torch_gridsample(low_stage, (h, w), delta2)
        aligned = high_stage1 + low_stage1

        # Gated refinement. Resizing to (h, w) matches the old scale_factor=2
        # result for exact 2x inputs (align_corners=True ignores the factor).
        # It also handles equal sizes, which previously fed channel counts
        # that conv2/conv3 cannot accept.
        if high_stage.shape[2:] != (h, w):
            high_up = F.interpolate(input=high_stage, size=(h, w), mode='bilinear', align_corners=True)
        else:
            high_up = high_stage
        gate = torch.sigmoid(self.conv2(high_up))
        low_stage2 = torch.cat((low_stage * gate, low_stage), dim=1)
        low_stage2 = F.relu(self.bn3(self.conv3(low_stage2)))

        return aligned, low_stage2


class CAB_high(nn.Module):
    def __init__(self, features):
        super(CAB_high, self).__init__()

        self.conv1 = nn.Conv2d(features * 4, features, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(features)

        self.maxpooling = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(features, 1, kernel_size=1)

        self.conv3 = nn.Conv2d(features * 8, features, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(features)

        self.delta_gen1 = nn.Sequential(
            nn.Conv2d(features * 2, features, kernel_size=1, bias=False),
            # InPlaceABNSync(features),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Conv2d(features, 2, kernel_size=3, padding=1, bias=False)
        )

        self.delta_gen2 = nn.Sequential(
            nn.Conv2d(features * 2, features, kernel_size=1, bias=False),
            # InPlaceABNSync(features),
            nn.BatchNorm2d(features),
            nn.ReLU(),
            nn.Conv2d(features, 2, kernel_size=3, padding=1, bias=False)
        )

        self.delta_gen1[3].weight.data.zero_()
        self.delta_gen2[3].weight.data.zero_()

    # https://github.com/speedinghzl/AlignSeg/issues/7
    # The normalization term is [w/s, h/s] rather than [h/s, w/s].
    # bilinear_interpolate_torch_gridsample2 is the standard implementation;
    # use it for training.
    def bilinear_interpolate_torch_gridsample(self, input, size, delta=0):
        """Sample ``input`` on an identity grid of ``size`` shifted by ``delta``.

        Args:
            input: 4-D (N, C, H, W) tensor to sample from.
            size: (out_h, out_w) of the sampling grid.
            delta: offsets, expected (N, 2, out_h, out_w), or a plain number
                (the ``0`` default) that is broadcast over the grid.

        Returns:
            (N, C, out_h, out_w) tensor.
        """
        out_h, out_w = size
        n, c, h, w = input.shape
        s = 1.0
        norm = torch.tensor([[[[w / s, h / s]]]], dtype=input.dtype, device=input.device)
        # Identity grid in [-1, 1]; grid_sample expects (x, y) in the last dim.
        ys = torch.linspace(-1.0, 1.0, out_h).view(-1, 1).repeat(1, out_w)
        xs = torch.linspace(-1.0, 1.0, out_w).repeat(out_h, 1)
        grid = torch.stack((xs, ys), dim=2).repeat(n, 1, 1, 1)
        grid = grid.to(device=input.device, dtype=input.dtype)
        if torch.is_tensor(delta):
            # Channels-last so the offsets line up with the grid's last dim;
            # a numeric default previously crashed on .permute.
            delta = delta.permute(0, 2, 3, 1)
        grid = grid + delta / norm
        # align_corners=False is PyTorch's default since 1.3; it is passed
        # explicitly to keep the same results without the default-change warning.
        return F.grid_sample(input, grid, align_corners=False)

    def bilinear_interpolate_torch_gridsample2(self, input, size, delta=0):
        out_h, out_w = size
        n, c, h, w = input.shape
        s = 2.0
        norm = torch.tensor([[

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值