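DBNet 的具体实现
1. FPN 特征金字塔(主干网络为 ResNet-50)
2. DB 头:两个独立分支分别预测概率图 prob_map 和阈值图 threshold_map(二者不能共用同一分支,否则阈值图与概率图完全相同)
3. Segout:由 prob_map 与 threshold_map 计算可微的近似二值图 ab_map
近似二值图 $\hat{B} = \dfrac{1}{1 + e^{-k(P - T)}}$,其中 $P$ 为 prob_map,$T$ 为 threshold_map,建议 $k=10$;对应代码 torch.reciprocal(1 + torch.exp(-k * (prob_map - threshold_map)))
PyTorch 实现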
import torch
import torch.nn as nn
# Components: Res (bottleneck) -> Resnet50 (backbone) -> FPN (neck) -> SegoutDetector (head), composed into DBnet
class DBnet(nn.Module):
    """DBNet text detector: Resnet50 backbone -> FPN neck -> SegoutDetector head.

    Calling the model returns whatever SegoutDetector returns: in training
    mode (prob_map, threshold_map, ab_map); in eval mode only prob_map.

    Args:
        serial: forwarded unchanged to SegoutDetector.
    """

    def __init__(self, serial=False):
        super().__init__()
        # Construction order (backbone, head, seg_out) fixes both the
        # state_dict key order and the order in which weights are drawn.
        self.backbone = Resnet50()
        self.head = FPN()
        self.seg_out = SegoutDetector(serial=serial)

    def forward(self, x):
        backbone_features = self.backbone(x)   # (c2, c3, c4, c5)
        pyramid = self.head(backbone_features)  # (p2, p3, p4, p5)
        return self.seg_out(pyramid)
class Res(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (x4) plus a shortcut.

    Args:
        in_channel: number of channels of the input tensor.
        inner_channel: bottleneck width; the output has 4 * inner_channel channels.
        stride: stride of the 3x3 conv; stride=2 halves the spatial size.
    """

    def __init__(self, in_channel, inner_channel, stride=1):
        super().__init__()
        self.expansion = 4
        out_channel = self.expansion * inner_channel
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channel, inner_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(inner_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_channel, inner_channel, kernel_size=3, stride=stride,
                      padding=1, bias=False),
            nn.BatchNorm2d(inner_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_channel, out_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(out_channel),
        )
        self.relu = nn.ReLU(inplace=True)
        # The identity shortcut can only be added directly when channels and
        # spatial size are unchanged; otherwise project it with a strided 1x1 conv.
        needs_projection = stride != 1 or in_channel != out_channel
        self.dsample = (
            nn.Sequential(
                nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channel),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        residual = self.bottleneck(x)
        shortcut = x if self.dsample is None else self.dsample(x)
        return self.relu(residual + shortcut)
class Resnet50(nn.Module):
    """ResNet-50 backbone returning the C2..C5 feature maps.

    Stage depths are [3, 4, 6, 3] bottleneck blocks. For a 3-channel input,
    C2..C5 have 256/512/1024/2048 channels at strides 4/8/16/32.
    """

    def __init__(self):
        super().__init__()
        # Stem: 7x7/2 conv followed by 3x3/2 max-pool (overall stride 4).
        self.make_c1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2,
                      padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        self.make_c2 = self._stage(64, 64, blocks=3, stride=1)
        self.make_c3 = self._stage(256, 128, blocks=4, stride=2)
        self.make_c4 = self._stage(512, 256, blocks=6, stride=2)
        self.make_c5 = self._stage(1024, 512, blocks=3, stride=2)

    @staticmethod
    def _stage(in_channel, inner_channel, blocks, stride):
        """Build one stage: the first Res block may downsample/widen, the rest keep shape."""
        widened = 4 * inner_channel  # Res expands its bottleneck width by 4
        layers = [Res(in_channel=in_channel, inner_channel=inner_channel, stride=stride)]
        layers.extend(
            Res(in_channel=widened, inner_channel=inner_channel, stride=1)
            for _ in range(blocks - 1)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        collected = []
        feat = self.make_c1(x)
        for stage in (self.make_c2, self.make_c3, self.make_c4, self.make_c5):
            feat = stage(feat)
            collected.append(feat)
        return tuple(collected)  # (c2, c3, c4, c5)
class FPN(nn.Module):
    """Feature Pyramid Network neck over ResNet-50 features.

    Maps (c2, c3, c4, c5) with 256/512/1024/2048 channels to 256-channel maps
    (p2, p3, p4, p5). Each level Pi is built top-down: its 1x1 lateral
    projection plus the nearest-neighbour upsampled, already merged P(i+1).
    """

    def __init__(self):
        super(FPN, self).__init__()
        self.make_p5 = nn.Sequential(nn.Conv2d(2048, 256, 1, 1, 0, bias=False),
                                     nn.BatchNorm2d(256),
                                     nn.ReLU(inplace=True))
        # Lateral connections: bring every level to 256 channels so they can be summed.
        self.lat_c4 = nn.Sequential(nn.Conv2d(1024, 256, 1, 1, 0, bias=False),
                                    nn.BatchNorm2d(256),
                                    nn.ReLU(inplace=True))
        self.lat_c3 = nn.Sequential(nn.Conv2d(512, 256, 1, 1, 0, bias=False),
                                    nn.BatchNorm2d(256),
                                    nn.ReLU(inplace=True))
        self.lat_c2 = nn.Sequential(nn.Conv2d(256, 256, 1, 1, 0, bias=False),
                                    nn.BatchNorm2d(256),
                                    nn.ReLU(inplace=True))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.constant_(m.bias, 1e-4)

    def _upsample_add(self, x, y):
        """Return y + x resized (nearest neighbour) to y's spatial size."""
        return y + nn.functional.interpolate(x, size=tuple(y.shape[2:]), mode='nearest')

    def forward(self, x):
        c2, c3, c4, c5 = x
        p5 = self.make_p5(c5)
        # Top-down pathway must be sequential: each level adds the already
        # merged coarser level. A single tuple assignment would evaluate all
        # right-hand sides first and feed the un-merged laterals downward,
        # cutting P3/P2 off from C5/C4.
        p4 = self._upsample_add(p5, self.lat_c4(c4))
        p3 = self._upsample_add(p4, self.lat_c3(c3))
        p2 = self._upsample_add(p3, self.lat_c2(c2))
        return p2, p3, p4, p5
class SegoutDetector(nn.Module):
    """DBNet segmentation head.

    Fuses (p2, p3, p4, p5) into one 256-channel map at p2's resolution, then
    predicts a probability map and, in training mode, a threshold map with two
    separate branches. Each branch upsamples x4 and ends in a sigmoid. The
    approximate binary map is sigmoid(k * (prob_map - threshold_map)).

    Args:
        serial: if True, the threshold branch also receives the probability
            map (resized to the fused map's size) as one extra input channel.
    """

    def __init__(self, serial=False):
        super(SegoutDetector, self).__init__()
        self.serial = serial
        # 3x3 conv shared by all four pyramid levels: 256 -> 64 channels each,
        # so concatenating four levels gives 4 * 64 = 256 channels.
        self.conv3 = nn.Sequential(nn.Conv2d(256, 64, 3, 1, 1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.ReLU(inplace=True))
        # probability map branch
        self.binarize = self._make_head(256)
        self.binarize.apply(self.weights_init)
        # Threshold map branch: it needs its own weights. Reusing `binarize`
        # would make threshold_map identical to prob_map and ab_map a constant 0.5.
        self.thresh = self._make_head(256 + (1 if serial else 0))
        self.thresh.apply(self.weights_init)

    @staticmethod
    def _make_head(in_channels):
        """3x3 conv -> BN -> ReLU -> 2x deconv -> BN -> ReLU -> 2x deconv to 1 channel -> sigmoid."""
        return nn.Sequential(
            nn.Conv2d(in_channels, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, 2, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, out_channels=1, kernel_size=2, stride=2),
            nn.Sigmoid(),
        )

    def forward(self, x):
        p2, p3, p4, p5 = x
        fuse = self.merge(p2, p3, p4, p5)
        # probability map
        prob_map = self.binarize(fuse)
        # At test time only the probability map is returned.
        if not self.training:
            return prob_map
        # threshold map
        thresh_in = fuse
        if self.serial:
            # prob_map is 4x larger than fuse; resize it back before concatenating.
            resized = nn.functional.interpolate(prob_map, size=tuple(fuse.shape[2:]), mode='nearest')
            thresh_in = torch.cat((fuse, resized), dim=1)
        threshold_map = self.thresh(thresh_in)
        # approximate binary map
        ab_map = self.ab_map(prob_map, threshold_map)
        return prob_map, threshold_map, ab_map

    def merge(self, p2, p3, p4, p5):
        """Apply conv3 to every level, resize to p2's size (nearest) and concatenate channels."""
        size = tuple(p2.shape[2:])
        levels = [self.conv3(p2)]
        for p in (p3, p4, p5):
            # Functional interpolation avoids registering a new submodule on every call.
            levels.append(nn.functional.interpolate(self.conv3(p), size=size, mode='nearest'))
        return torch.cat(levels, dim=1)

    def ab_map(self, x, y, k=10):
        """Approximate binary map 1 / (1 + exp(-k * (x - y))), computed as a stable sigmoid."""
        return torch.sigmoid(k * (x - y))

    def weights_init(self, m):
        """Kaiming-normal init for conv/deconv weights; BN weight 1 and bias 1e-4."""
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)
if __name__ == "__main__":
    # Smoke test: build the model, report its state_dict size and run one
    # training-mode forward pass on a random batch.
    model = DBnet()
    print(len(model.state_dict()))
    dummy = torch.randn(2, 3, 512, 512)
    prob_map, threshold_map, approx_binary = model(dummy)
    print(prob_map.shape, threshold_map.shape, approx_binary.shape)