DBnet实现

DBnet的具体实现

1.FPN(主干为resnet50)
fpn
2.DB(可微分二值化:由两个预测分支分别得到 prob_map 和 threshold_map)
二值化

3.Segout
近似二值图 $\hat{B} = \dfrac{1}{1 + e^{-k\,(P - T)}}$,其中 $P$ 为 prob_map,$T$ 为 threshold_map;代码实现:torch.reciprocal(1 + torch.exp(-k * (prob_map - threshold_map)))(建议k=10)

pytorch实现

import torch
import torch.nn as nn

# Classes: Res (bottleneck) -> Resnet50 (backbone), FPN (neck), SegoutDetector (head); DBnet composes them.

class DBnet(nn.Module):
    """DBNet text detector: ResNet-50 backbone -> FPN -> segmentation head.

    Args:
        serial: Forwarded unchanged to ``SegoutDetector``.
    """

    def __init__(self, serial=False):
        super().__init__()
        # Attribute names double as state_dict prefixes, so they are kept as-is.
        self.backbone = Resnet50()
        self.head = FPN()
        self.seg_out = SegoutDetector(serial=serial)

    def forward(self, x):
        """Run the full pipeline on an image batch.

        Returns whatever the segmentation head returns: the tuple
        ``(prob_map, threshold_map, ab_map)`` in training mode, or only
        ``prob_map`` in eval mode.
        """
        stage_features = self.backbone(x)
        pyramid = self.head(stage_features)
        return self.seg_out(pyramid)

class Res(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus shortcut.

    The block outputs ``4 * inner_channel`` channels. The 3x3 convolution
    carries ``stride``, so ``stride=2`` halves the feature-map size.

    Args:
        in_channel: Number of input channels.
        inner_channel: Width of the bottleneck (output is 4x this).
        stride: Stride of the 3x3 convolution and of the projection shortcut.
    """

    def __init__(self, in_channel, inner_channel, stride=1):
        super().__init__()
        self.expansion = 4
        out_channel = self.expansion * inner_channel

        main_path = [
            nn.Conv2d(in_channel, inner_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(inner_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_channel, inner_channel, kernel_size=3,
                      stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(inner_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_channel, out_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channel),
        ]
        self.bottleneck = nn.Sequential(*main_path)
        self.relu = nn.ReLU(inplace=True)

        # A projection shortcut is only needed when the residual sum would
        # otherwise mix tensors of different channel count or spatial size.
        shape_changes = stride != 1 or in_channel != out_channel
        self.dsample = (
            nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channel),
            )
            if shape_changes
            else None
        )

    def forward(self, x):
        """Return ``relu(bottleneck(x) + shortcut(x))``."""
        out = self.bottleneck(x)
        shortcut = x if self.dsample is None else self.dsample(x)
        out += shortcut
        return self.relu(out)

class Resnet50(nn.Module):
    """ResNet-50 feature extractor built from ``Res`` bottleneck blocks.

    The stem (7x7 stride-2 conv + 3x3 stride-2 max-pool) is followed by four
    stages of [3, 4, 6, 3] blocks producing 256, 512, 1024 and 2048 channels.
    Stages c3-c5 each start with a stride-2 block.
    """

    @staticmethod
    def _stage(in_channel, inner_channel, num_blocks, stride):
        """Stack ``num_blocks`` Res blocks; only the first changes width/stride."""
        # Every Res block emits 4 * inner_channel channels, which is what the
        # remaining blocks of the stage receive as input.
        out_channel = 4 * inner_channel
        blocks = [Res(in_channel=in_channel, inner_channel=inner_channel, stride=stride)]
        blocks.extend(
            Res(in_channel=out_channel, inner_channel=inner_channel, stride=1)
            for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*blocks)

    def __init__(self):
        super().__init__()
        self.make_c1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7,
                      stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # Block counts per stage: [3, 4, 6, 3]
        self.make_c2 = self._stage(64, 64, num_blocks=3, stride=1)
        self.make_c3 = self._stage(256, 128, num_blocks=4, stride=2)
        self.make_c4 = self._stage(512, 256, num_blocks=6, stride=2)
        self.make_c5 = self._stage(1024, 512, num_blocks=3, stride=2)

    def forward(self, x):
        """Return the stage outputs as the tuple ``(c2, c3, c4, c5)``."""
        feature = self.make_c1(x)
        stage_outputs = []
        for stage in (self.make_c2, self.make_c3, self.make_c4, self.make_c5):
            feature = stage(feature)
            stage_outputs.append(feature)
        return tuple(stage_outputs)

class FPN(nn.Module):
    """Feature pyramid over ResNet-50 stages c2..c5 (all outputs 256 channels).

    Each stage is projected to 256 channels by a 1x1 conv + BN + ReLU lateral
    branch. The top-down pathway then adds each coarser, already-merged level
    (nearest-neighbour upsampled) onto the next finer lateral.
    """

    @staticmethod
    def _lateral(in_channels):
        """Build a 1x1 conv + BN + ReLU branch mapping ``in_channels`` to 256."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 256, 1, 1, 0, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        super(FPN, self).__init__()
        self.make_p5 = self._lateral(2048)
        # Lateral connections: bring every stage to the same channel count.
        self.lat_c4 = self._lateral(1024)
        self.lat_c3 = self._lateral(512)
        self.lat_c2 = self._lateral(256)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight.data)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1.)
                m.bias.data.fill_(1e-4)

    def _upsample_add(self, x, y):
        """Resize ``x`` to ``y``'s spatial size (nearest) and add it to ``y``."""
        # Functional form: no nn.Upsample module is constructed per call.
        return y + nn.functional.interpolate(x, size=y.shape[2:], mode='nearest')

    def forward(self, x):
        """Build the pyramid.

        Args:
            x: Tuple ``(c2, c3, c4, c5)`` of backbone features.

        Returns:
            Tuple ``(p2, p3, p4, p5)``.
        """
        c2, c3, c4, c5 = x
        p5 = self.make_p5(c5)
        # The levels must be computed one after another: each finer level uses
        # the *merged* coarser level. A single tuple assignment would evaluate
        # every right-hand side first and silently use the un-merged laterals.
        p4 = self._upsample_add(p5, self.lat_c4(c4))
        p3 = self._upsample_add(p4, self.lat_c3(c3))
        p2 = self._upsample_add(p3, self.lat_c2(c2))
        return p2, p3, p4, p5

class SegoutDetector(nn.Module):
    """DB segmentation head: probability, threshold and approximate binary maps.

    The four pyramid levels are reduced to 64 channels each by one shared
    3x3 conv block, resized to the finest level and concatenated into 256
    channels. Two structurally identical heads with *separate* weights then
    predict the probability map and the threshold map. Each head upsamples
    by 4x through two stride-2 transposed convolutions.

    Args:
        serial: When True, the threshold head also receives the probability
            map (resized to the fused feature size) as one extra input
            channel. When False it sees only the fused features.

    Note:
        The threshold map is produced by its own ``thresh`` head. Predicting
        it with ``binarize`` would make it identical to the probability map
        and pin the approximate binary map at 0.5. The state_dict therefore
        contains ``thresh.*`` entries in addition to ``binarize.*``.
    """

    def __init__(self, serial=False):
        super(SegoutDetector, self).__init__()
        self.serial = serial
        self.conv3 = nn.Sequential(nn.Conv2d(256, 64, 3, 1, 1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.ReLU(inplace=True))
        # probability map head
        self.binarize = self._make_head(256)
        # threshold map head (one extra input channel in serial mode)
        self.thresh = self._make_head(256 + 1 if serial else 256)
        self.binarize.apply(self.weights_init)
        self.thresh.apply(self.weights_init)

    @staticmethod
    def _make_head(in_channels):
        """Build a conv -> 2x deconv -> 2x deconv -> sigmoid prediction head."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, 2, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels=1, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Predict the maps from the pyramid ``(p2, p3, p4, p5)``.

        Returns:
            In eval mode, only ``prob_map``. In training mode, the tuple
            ``(prob_map, threshold_map, ab_map)``.
        """
        p2, p3, p4, p5 = x
        fuse = self.merge(p2, p3, p4, p5)
        # probability map
        prob_map = self.binarize(fuse)
        # At test time only the probability map is needed.
        if not self.training:
            return prob_map
        # threshold map
        thresh_in = fuse
        if self.serial:
            coarse_prob = nn.functional.interpolate(
                prob_map, size=fuse.shape[2:], mode='nearest')
            thresh_in = torch.cat((fuse, coarse_prob), dim=1)
        threshold_map = self.thresh(thresh_in)
        # approximate binary map
        ab_map = self.ab_map(prob_map, threshold_map)

        return prob_map, threshold_map, ab_map

    def merge(self, p2, p3, p4, p5):
        """Reduce each level with the shared conv3, resize to p2, concatenate."""
        target_size = p2.shape[2:]
        reduced = [self.conv3(p2)]
        for level in (p3, p4, p5):
            # Functional resize: assigning an nn.Upsample to ``self`` here would
            # register a new submodule as a side effect of every forward pass.
            reduced.append(nn.functional.interpolate(
                self.conv3(level), size=target_size, mode='nearest'))
        return torch.cat(reduced, dim=1)

    # approximate binary map
    def ab_map(self, x, y, k=10):
        """Return ``1 / (1 + exp(-k * (x - y)))``, i.e. ``sigmoid(k * (x - y))``."""
        return torch.sigmoid(k * (x - y))

    def weights_init(self, m):
        """Kaiming-init conv/deconv weights; set BatchNorm weight=1, bias=1e-4."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

if __name__ == "__main__":
    # Smoke test: build the model (training mode by default, so the forward
    # pass yields three maps), report the state_dict size and output shapes.
    model = DBnet()
    print(len(model.state_dict()))
    batch = torch.randn(2, 3, 512, 512)
    prob_map, threshold_map, approx_binary = model(batch)
    print(prob_map.shape, threshold_map.shape, approx_binary.shape)




  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值