1、basnet边缘注意力分割网络原理介绍:
耗时:在自己gtx1070Ti的电脑上分割尺寸(224,1120)得出的耗时是: 56ms左右波动
另一个比basnet好的网络U^2-net在公众号各种好的工程二, 这个网络耗时在相同硬件下,224的输入,耗时112ms。
basnet主要使用了一个Residual Refinement Module跟SSIM loss 其中SSIM(structural similarity index),其解码网络主要有7个输出,然后通过Residual Refinement Module进行输出一个,总共有8个输出,他们的输出W H大小跟输入数据的W H一样。
basnet网络其目标分割效果很好,除了使用上面的RRM网络结构外还有一个是使用了SSIM loss 这个损失会计算预测输出跟打标mask的亮度、对比度、图像结构这三者的损失。
- BASNET网络的结构如下:
注意:上面是一整个网络,step1-8是一个大小跟输出一样的分割图,其是使用sigmoid进行激活的,而step8是使用了step7,然后再使用了一个RRM网络进行输出,效果出奇的好。
下面是basnet网络的输出部分的forward部分代码:
def forward(self,x):
.....
## -------------Side Output-------------
db = self.outconvb(hbg)
db = self.upscore6(db) # 8->256
d6 = self.outconv6(hd6)
d6 = self.upscore6(d6) # 8->256
d5 = self.outconv5(hd5)
d5 = self.upscore5(d5) # 16->256
d4 = self.outconv4(hd4)
d4 = self.upscore4(d4) # 32->256
d3 = self.outconv3(hd3)
d3 = self.upscore3(d3) # 64->256
d2 = self.outconv2(hd2)
d2 = self.upscore2(d2) # 128->256
d1 = self.outconv1(hd1) # 256
## -------------Refine Module-------------
dout = self.refunet(d1) # 256
return F.sigmoid(dout), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(
d6), F.sigmoid(db)
训练时候的输出计算:
# forward + backward + optimize
d0, d1, d2, d3, d4, d5, d6, d7 = net(inputs_v)
loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, d7, labels_v)
loss.backward()
optimizer.step()
muti_bce_loss_fusion函数的计算为:
def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, d7, labels_v):
loss0 = bce_ssim_loss(d0, labels_v)
loss1 = bce_ssim_loss(d1, labels_v)
loss2 = bce_ssim_loss(d2, labels_v)
loss3 = bce_ssim_loss(d3, labels_v)
loss4 = bce_ssim_loss(d4, labels_v)
loss5 = bce_ssim_loss(d5, labels_v)
loss6 = bce_ssim_loss(d6, labels_v)
loss7 = bce_ssim_loss(d7, labels_v)
# ssim0 = 1 - ssim_loss(d0,labels_v)
# iou0 = iou_loss(d0,labels_v)
# loss = torch.pow(torch.mean(torch.abs(labels_v-d0)),2)*(5.0*loss0 + loss1 + loss2 + loss3 + loss4 + loss5) #+ 5.0*lossa
loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6 + loss7 # + 5.0*lossa
print("l0: %3f, l1: %3f, l2: %3f, l3: %3f, l4: %3f, l5: %3f, l6: %3f\n" % (
loss0.data, loss1.data, loss2.data, loss3.data, loss4.data, loss5.data, loss6.data))
# print("BCE: l1:%3f, l2:%3f, l3:%3f, l4:%3f, l5:%3f, la:%3f, all:%3f\n"%(loss1.data[0],loss2.data[0],loss3.data[0],loss4.data[0],loss5.data[0],lossa.data[0],loss.data[0]))
return loss0, loss
2、SSIM_loss(Structural SIMilarity)的计算原理:
其中basnet的论文里时有三个loss的分别是:
而参考工程里没有计算bce、iou。
使用pytorch的conv2d进行计算ssim损失的参考代码,SSIM构建的window就是通过构建卷积核大小来实现的。
参考博客:损失函数SSIM (structural similarity index) 的PyTorch实现
其中SSIM的loss公式原理介绍,论文笔记-损失函数之SSIM