DFN-Model (PyTorch version)


This post explains how the Border Network comes into play and how the two sub-networks are fused. If you have questions, feel free to get in touch.

1. Start with the loss functions

# -*- coding: utf-8 -*-

import tensorflow as tf

def pw_softmaxwithloss_2d(y_true, y_pred):
	
	# pixel-wise softmax over the channel axis (NHWC layout)
	exp_pred = tf.exp(y_pred)
	
	try:
		
		sum_exp = tf.reduce_sum(exp_pred, 3, keepdims=True)
	
	except:
		
		# older TF versions spell the argument keep_dims
		sum_exp = tf.reduce_sum(exp_pred, 3, keep_dims=True)
	
	tensor_sum_exp = tf.tile(sum_exp, tf.stack([1, 1, 1, tf.shape(y_pred)[3]]))
	softmax_output = tf.div(exp_pred, tensor_sum_exp)
	# cross-entropy against the one-hot y_true, clipped for numerical stability
	ce = - tf.reduce_mean(y_true * tf.log(tf.clip_by_value(softmax_output, 1e-12, 1.0)))
	
	return softmax_output, ce

def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0):
	
	try:
		
		# pk: probability assigned to the ground-truth class at each pixel
		pk = tf.reduce_sum(y_true * y_pred, 3, keepdims=True)
	
	except:
		
		pk = tf.reduce_sum(y_true * y_pred, 3, keep_dims=True)
	
	# down-weight easy pixels with the (1 - pk)^gamma modulating factor
	fl = - alpha * tf.reduce_mean(tf.pow(1.0 - pk, gamma) * tf.log(tf.clip_by_value(pk, 1e-12, 1.0)))
	
	return fl

There are two loss functions in total: pw_softmaxwithloss_2d(y_true, y_pred) and focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0).

# loss1
def pw_softmaxwithloss_2d(y_true, y_pred)
# loss2
def focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0)
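
Since this is the PyTorch version of the model, here is a minimal PyTorch sketch of the same two losses. This is my own translation of the TensorFlow code above, not code from the original repo: it assumes NCHW tensors with a one-hot y_true along the channel dimension, and the *_torch helper names are mine.

import torch

def pw_softmaxwithloss_2d_torch(y_true, y_pred, eps=1e-12):
    # softmax over the class dimension, then pixel-wise cross-entropy
    softmax_output = torch.softmax(y_pred, dim=1)
    ce = -torch.mean(y_true * torch.log(softmax_output.clamp(eps, 1.0)))
    return softmax_output, ce

def focal_loss_torch(y_true, y_pred, alpha=0.25, gamma=2.0, eps=1e-12):
    # pk: probability assigned to the ground-truth class at each pixel
    # (as in the TF code, y_pred is assumed to already be a probability map)
    pk = torch.sum(y_true * y_pred, dim=1, keepdim=True)
    fl = -alpha * torch.mean((1.0 - pk) ** gamma * torch.log(pk.clamp(eps, 1.0)))
    return fl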

2. The network

(figure: overall DFN architecture)

The DFN model is split into two sub-networks, the Border Network and the Smooth Network.

The Smooth Network does the actual segmentation and is supervised with pw_softmaxwithloss_2d(y_true, y_pred).
The Border Network makes the model pay attention to object boundaries and is supervised with focal_loss(y_true, y_pred, alpha=0.25, gamma=2.0).

From the two losses you can already see that the Border Network and the Smooth Network produce y_pred feature maps of the same shape.

The actual model outputs are indeed two feature maps of the same shape, as shown below:

# randomly generate input data
image = torch.randn(1, 3, 512, 512)
# build the network
net = DFN(21)
net.eval()
# forward pass
res1, res2 = net(image)
# print the output sizes
print('-----'*5)
print(res1.size(), res2.size())
print('-----'*5)

(console output: res1 and res2 have the same size)

Each sub-network's output is then combined with the label y_true to compute its loss.

Below is the loss design used during actual training:

	def loss(self):
		
		######### -*- Softmax Loss -*- #########
		# self.b1 ... self.b4 and self.fuse are the Smooth Network outputs
		self.softmax_b1, self.ce1 = pw_softmaxwithloss_2d(self.Y, self.b1)
		self.softmax_b2, self.ce2 = pw_softmaxwithloss_2d(self.Y, self.b2)
		self.softmax_b3, self.ce3 = pw_softmaxwithloss_2d(self.Y, self.b3)
		self.softmax_b4, self.ce4 = pw_softmaxwithloss_2d(self.Y, self.b4)
		self.softmax_fuse, self.cefuse = pw_softmaxwithloss_2d(self.Y, self.fuse)
		self.total_ce = self.ce1 + self.ce2 + self.ce3 + self.ce4 + self.cefuse
		
		######### -*- Focal Loss -*- #########
		# self.o is the Border Network output
		self.fl = focal_loss(self.Y, self.o, alpha=self.alpha, gamma=self.gamma)
		
		######### -*- Total Loss -*- #########
		self.total_loss = self.total_ce + self.fl_weight * self.fl

Note the last line:

self.total_loss = self.total_ce + self.fl_weight * self.fl

This is the dual-network fusion: the two sub-networks are "fused" during training simply by merging their losses into one total loss.
The focal loss supervises the output of the Border Network.
It is called the Border Network because its design makes it focus on boundaries.
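
In PyTorch terms, the fusion boils down to something like the following sketch. This is my own code, not the original repo's: fl_weight, the one-hot y_true, and the *_torch loss helpers above are assumptions, and DFN is the network defined further down in this post.

import torch

fl_weight = 0.1                                # hypothetical weight for the focal loss term

image = torch.randn(1, 3, 512, 512)
net = DFN(21)                                  # the DFN module defined later in this post
net.eval()
res1, res2 = net(image)                        # res1: Border Network, res2: Smooth Network

# one shared one-hot label at the output resolution (both outputs have the same shape)
y_true = torch.zeros_like(res2)
y_true[:, 0] = 1.0

_, ce = pw_softmaxwithloss_2d_torch(y_true, res2)           # Smooth Network: segmentation CE
fl = focal_loss_torch(y_true, torch.softmax(res1, dim=1))   # Border Network: focal loss
total_loss = ce + fl_weight * fl               # loss-level "fusion" of the two sub-networks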

Overall explanation diagram (figure omitted)

One more point to emphasize:

	def loss(self):
		
		######### -*- Softmax Loss -*- #########
		1.self.softmax_b1, self.ce1 = pw_softmaxwithloss_2d(self.Y, self.b1)
		2.self.softmax_b2, self.ce2 = pw_softmaxwithloss_2d(self.Y, self.b2)
		3.self.softmax_b3, self.ce3 = pw_softmaxwithloss_2d(self.Y, self.b3)
		
		4.self.softmax_b4, self.ce4 = pw_softmaxwithloss_2d(self.Y, self.b4)
		5.self.softmax_fuse, self.cefuse = pw_softmaxwithloss_2d(self.Y, self.fuse)
		6.self.total_ce = self.ce1 + self.ce2 + self.ce3 + self.ce4 + self.cefuse
		
		######### -*- Focal Loss -*- #########
		7.self.fl = focal_loss(self.Y, self.o, alpha=self.alpha, gamma=self.gamma)
		
		######### -*- Total Loss -*- #########
		self.total_loss = self.total_ce + self.fl_weight * self.fl
	

Looking at the loss code above, lines 1-6 and line 7 all use self.Y, so the two sub-networks are trained against the same label (y_true).
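
For completeness, a hedged sketch of how such a shared label can be prepared in PyTorch: an integer class mask is converted into the one-hot y_true that both losses consume (the shapes and names here are my assumptions, not the repo's).

import torch
import torch.nn.functional as F

num_classes = 21
mask = torch.randint(0, num_classes, (1, 128, 128))   # N, H, W integer class labels
y_true = F.one_hot(mask, num_classes)                 # N, H, W, C
y_true = y_true.permute(0, 3, 1, 2).float()           # N, C, H, W, matching the model output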

3. Training and validation code logic




All.ipynb


(screenshots of the training and validation cells in All.ipynb omitted)
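
The notebook screenshots do not carry over here, so below is only a hedged skeleton of one training step with the two losses fused. This is my own sketch, not the code from All.ipynb; fl_weight and the *_torch helpers are the assumed names from the sketches above.

import torch

num_classes, fl_weight = 21, 0.1
net = DFN(num_classes)                        # the DFN module defined below
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3, momentum=0.9)

def train_step(image, y_true_onehot):
    # one optimisation step: forward both branches, fuse the losses, backprop
    net.train()
    optimizer.zero_grad()
    res1, res2 = net(image)                                           # Border / Smooth outputs
    _, ce = pw_softmaxwithloss_2d_torch(y_true_onehot, res2)          # Smooth Network CE
    fl = focal_loss_torch(y_true_onehot, torch.softmax(res1, dim=1))  # Border Network focal loss
    total_loss = ce + fl_weight * fl
    total_loss.backward()
    optimizer.step()
    return total_loss.item()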


from PIL import Image
import torch.nn as nn
from torchvision import models
import torch
resnet101 = models.resnet101(pretrained=False)
class RRB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(RRB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        res = self.conv2(x)
        res = self.bn(res)
        res = self.relu(res)
        res = self.conv3(res)
        return self.relu(x + res)
class CAB(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(CAB, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x1, x2 = x  # high, low
        x = torch.cat([x1, x2], dim=1)
        x = self.global_pooling(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x2 = x * x2
        res = x2 + x1
        return res
class DFN(nn.Module):
    def __init__(self, num_class=21):
        super(DFN, self).__init__()
        self.num_class = num_class
        self.layer0 = nn.Sequential(resnet101.conv1, resnet101.bn1, resnet101.relu)
        self.layer1 = nn.Sequential(resnet101.maxpool, resnet101.layer1)
        self.layer2 = resnet101.layer2
        self.layer3 = resnet101.layer3
        self.layer4 = resnet101.layer4

        # this is for smooth network
        self.out_conv = nn.Conv2d(2048, self.num_class, kernel_size=1, stride=1)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.cab1 = CAB(self.num_class*2, self.num_class)
        self.cab2 = CAB(self.num_class*2, self.num_class)
        self.cab3 = CAB(self.num_class*2, self.num_class)
        self.cab4 = CAB(self.num_class*2, self.num_class)

        self.rrb_d_1 = RRB(256, self.num_class)
        self.rrb_d_2 = RRB(512, self.num_class)
        self.rrb_d_3 = RRB(1024, self.num_class)
        self.rrb_d_4 = RRB(2048, self.num_class)

        self.upsample = nn.Upsample(scale_factor=2,mode="bilinear")
        self.upsample_4 = nn.Upsample(scale_factor=4, mode="bilinear")
        self.upsample_8 = nn.Upsample(scale_factor=8, mode="bilinear")

        self.rrb_u_4 = RRB(self.num_class,self.num_class)
        self.rrb_u_3 = RRB(self.num_class,self.num_class)
        self.rrb_u_2 = RRB(self.num_class,self.num_class)
        self.rrb_u_1 = RRB(self.num_class,self.num_class)

        # this is for the border network
        self.rrb_db_1 = RRB(256, self.num_class)
        self.rrb_db_2 = RRB(512, self.num_class)
        self.rrb_db_3 = RRB(1024, self.num_class)
        self.rrb_db_4 = RRB(2048, self.num_class)

        self.rrb_trans_1 = RRB(self.num_class,self.num_class)
        self.rrb_trans_2 = RRB(self.num_class,self.num_class)
        self.rrb_trans_3 = RRB(self.num_class,self.num_class)

    def forward(self, x):
        f0 = self.layer0(x)  # 256, 256, 64
        f1 = self.layer1(f0)  # 128, 128, 256
        f2 = self.layer2(f1)  # 64, 64, 512
        f3 = self.layer3(f2)  # 32, 32, 1024
        f4 = self.layer4(f3)  # 16, 16, 2048

        # for border network
        res1 = self.rrb_db_1(f1)
        res1 = self.rrb_trans_1(res1 + self.upsample(self.rrb_db_2(f2)))
        res1 = self.rrb_trans_2(res1 + self.upsample_4(self.rrb_db_3(f3)))
        res1 = self.rrb_trans_3(res1 + self.upsample_8(self.rrb_db_4(f4)))      # 128, 128, 21

        # for smooth network
        res2 = self.out_conv(f4)    # 16, 16, 21
        res2 = self.global_pool(res2)  #
        res2 = nn.Upsample(size=f4.size()[2:],mode="nearest")(res2)     # 16, 16, 21

        f4 = self.rrb_d_4(f4)
        res2 = self.cab4([res2, f4])
        res2 = self.rrb_u_4(res2)

        f3 = self.rrb_d_3(f3)
        res2 = self.cab3([self.upsample(res2), f3])
        res2 = self.rrb_u_3(res2)

        f2 = self.rrb_d_2(f2)
        res2 = self.cab2([self.upsample(res2), f2])
        res2 = self.rrb_u_2(res2)

        f1 = self.rrb_d_1(f1)
        res2 = self.cab1([self.upsample(res2), f1])
        res2 = self.rrb_u_1(res2)

        return res1, res2
# randomly generate input data
image = torch.randn(1, 3, 512, 512)
# build the network
net = DFN(21)
net.eval()
# forward pass
res1, res2 = net(image)
# print the output sizes
print('-----'*5)
print(res1.size(), res2.size())
print('-----'*5)

With this version both outputs come back at 1/4 of the input resolution (128 × 128 for a 512 × 512 input). Below is a modified version in which the Border Network and Smooth Network branches are upsampled further, so res1 and res2 are returned at the full 512 × 512 input resolution.

from PIL import Image
import torch.nn as nn
from torchvision import models
import torch


resnet101 = models.resnet101(pretrained=False)


class RRB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(RRB, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = self.conv1(x)
        res = self.conv2(x)
        res = self.bn(res)
        res = self.relu(res)
        res = self.conv3(res)
        return self.relu(x + res)


class CAB(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(CAB, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x):
        x1, x2 = x  # high, low
        x = torch.cat([x1, x2], dim=1)
        x = self.global_pooling(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x2 = x * x2
        res = x2 + x1
        return res


class DFN(nn.Module):
    def __init__(self, num_class=21):
        super(DFN, self).__init__()
        self.num_class = num_class
        self.layer0 = nn.Sequential(resnet101.conv1, resnet101.bn1, resnet101.relu)
        self.layer1 = nn.Sequential(resnet101.maxpool, resnet101.layer1)
        self.layer2 = resnet101.layer2
        self.layer3 = resnet101.layer3
        self.layer4 = resnet101.layer4

        # this is for smooth network
        self.out_conv = nn.Conv2d(2048, self.num_class, kernel_size=1, stride=1)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.cab1 = CAB(self.num_class*2, self.num_class)
        self.cab2 = CAB(self.num_class*2, self.num_class)
        self.cab3 = CAB(self.num_class*2, self.num_class)
        self.cab4 = CAB(self.num_class*2, self.num_class)

        self.rrb_d_1 = RRB(256, self.num_class)
        self.rrb_d_2 = RRB(512, self.num_class)
        self.rrb_d_3 = RRB(1024, self.num_class)
        self.rrb_d_4 = RRB(2048, self.num_class)

        self.upsample = nn.Upsample(scale_factor=2,mode="bilinear")
        self.upsample_4 = nn.Upsample(scale_factor=4, mode="bilinear")
        self.upsample_8 = nn.Upsample(scale_factor=8, mode="bilinear")
        self.upsample_16 = nn.Upsample(scale_factor=16, mode="bilinear")
        self.upsample_32 = nn.Upsample(scale_factor=32, mode="bilinear")


        self.rrb_u_4 = RRB(self.num_class,self.num_class)
        self.rrb_u_3 = RRB(self.num_class,self.num_class)
        self.rrb_u_2 = RRB(self.num_class,self.num_class)
        self.rrb_u_1 = RRB(self.num_class,self.num_class)

        # this is for the border network
        self.rrb_db_1 = RRB(256, self.num_class)
        self.rrb_db_2 = RRB(512, self.num_class)
        self.rrb_db_3 = RRB(1024, self.num_class)
        self.rrb_db_4 = RRB(2048, self.num_class)

        self.rrb_trans_1 = RRB(self.num_class,self.num_class)
        self.rrb_trans_2 = RRB(self.num_class,self.num_class)
        self.rrb_trans_3 = RRB(self.num_class,self.num_class)

    def forward(self, x):
        f0 = self.layer0(x)  # 256, 256, 64
        f1 = self.layer1(f0)  # 128, 128, 256
        f2 = self.layer2(f1)  # 64, 64, 512
        f3 = self.layer3(f2)  # 32, 32, 1024
        f4 = self.layer4(f3)  # 16, 16, 2048

        # for border network
        res1 = self.rrb_db_1(self.upsample_4(f1))                            # 512/128 = 4
        res1 = self.rrb_trans_1(res1 + self.upsample_8(self.rrb_db_2(f2)))   # 512/64 = 8
        res1 = self.rrb_trans_2(res1 + self.upsample_16(self.rrb_db_3(f3)))  # 512/32 = 16
        res1 = self.rrb_trans_3(res1 + self.upsample_32(self.rrb_db_4(f4)))  # 512/16 = 32; such a large upsampling factor is questionable

        # for smooth network
        res2 = self.out_conv(f4)    # 16, 16, 21
        res2 = self.global_pool(res2)  #
        res2 = nn.Upsample(size=f4.size()[2:],mode="nearest")(res2)     # 16, 16, 21

        f4 = self.rrb_d_4(f4)
        res2 = self.cab4([res2, f4])
        res2 = self.rrb_u_4(res2)

        f3 = self.rrb_d_3(f3)
        res2 = self.cab3([self.upsample(res2), f3])
        res2 = self.rrb_u_3(res2)

        f2 = self.rrb_d_2(f2)
        res2 = self.cab2([self.upsample(res2), f2])
        res2 = self.rrb_u_2(res2)

        f1 = self.rrb_d_1(self.upsample_4(f1))  # f1 is forcibly upsampled to full resolution at the end
        res2 = self.cab1([self.upsample_8(res2), f1])
        res2 = self.rrb_u_1(res2)

        return res1, res2


if __name__ == "__main__":
    # randomly generate input data
    image = torch.randn(1, 3, 512, 512)
    # build the network
    net = DFN(2)
    net.eval()
    # forward pass
    res1, res2 = net(image)
    # print the output sizes
    print('-----' * 5)
    print(res1.size(), res2.size())
    print('-----' * 5)
