FarSeg: a segmentation network optimized for foreground-background imbalance in remote sensing imagery

Introduction

This post reproduces the paper Foreground-Aware Relation Network for Geospatial Object Segmentation in
High Spatial Resolution Remote Sensing Imagery.
The method targets HSR (high spatial resolution) remote sensing images, and its main goal is to address the foreground-background imbalance in such images.
[Figure: challenges of remote sensing imagery]
As the figure above shows, compared with natural scenes, remote sensing images pose three main challenges:
(1) large variation in object scale;
(2) foreground-background imbalance;
(3) large intra-class variance.
[Figure: overall FarSeg architecture]
The overall architecture, shown above, has four parts:
(a) Multi-Branch Encoder: an FPN plus a scene embedding branch (implemented with global average pooling, GAP);
(b) Foreground-Scene Relation: computed from the formulas illustrated below;
[Figure: foreground-scene relation module]
(c) Light-weight Decoder: transposed convolutions whose outputs are then summed;
(d) Foreground-Aware Optimization: a loss function with re-weighted terms, as shown below:
[Figure: foreground-aware loss function]
where ζ(t) denotes an annealing function.
This post reproduces parts (a), (b), and (c).
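Part (d) is not reproduced in code below, but the idea behind the annealed, foreground-aware weighting can be sketched in NumPy. The concrete choices here (a cosine annealing schedule for ζ(t) and a focal-style (1 - p)^γ factor) are assumptions for illustration, not a verbatim transcription of the paper:

```python
import numpy as np

def zeta(t, T):
    """Cosine annealing from 1 down to 0 over T steps (one common choice)."""
    return 0.5 * (1.0 + np.cos(np.pi * t / T))

def foreground_aware_weight(p, t, T, gamma=2.0):
    """Per-pixel loss weight: starts near uniform (1.0) and anneals toward a
    focal-style factor (1 - p)^gamma that emphasises hard examples.
    p is the predicted probability of the true class."""
    z = zeta(t, T)
    return z * 1.0 + (1.0 - z) * (1.0 - p) ** gamma

# Early in training the weights are nearly uniform...
early = foreground_aware_weight(np.array([0.9, 0.5, 0.1]), t=0, T=100)
# ...late in training, easy pixels (p close to 1) are strongly down-weighted.
late = foreground_aware_weight(np.array([0.9, 0.5, 0.1]), t=100, T=100)
```

The annealing lets the network first learn from all pixels equally, then gradually focus the gradient on hard (mostly foreground) pixels, which is how the paper counters the background-dominated loss.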

Code reproduction

Multi-Branch Encoder
This module is straightforward: it is just an FPN plus GAP, so it needs little explanation.

    ######################### Multi-Branch Encoder
    # Backbone: four stride-2 convs with BN + ReLU stand in for a ResNet, giving C2..C5
    C2=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(x,channel,3,strides=2, padding='same')))
    C3=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(C2,channel,3,strides=2, padding='same')))
    C4=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(C3,channel,3,strides=2, padding='same')))
    C5=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(C4,channel,3,strides=2, padding='same')))
    # FPN top-down pathway: 1x1 lateral convs plus 2x upsampling via transposed convs
    P5=tf.layers.conv2d(C5,channel,1,strides=1, padding='same')
    P4=tf.layers.conv2d(C4,channel,1,strides=1, padding='same')+tf.layers.conv2d_transpose(P5,channel,3,strides=2, padding='same')
    P3=tf.layers.conv2d(C3,channel,1,strides=1, padding='same')+tf.layers.conv2d_transpose(P4,channel,3,strides=2, padding='same')
    P2=tf.layers.conv2d(C2,channel,1,strides=1, padding='same')+tf.layers.conv2d_transpose(P3,channel,3,strides=2, padding='same')
    # Scene embedding branch: global average pooling over C5
    C6=tf.keras.layers.GlobalAvgPool2D()(C5)
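As a quick sanity check on the encoder, with 'same' padding each stride-2 conv halves the spatial size, so for the 224×224 input used in the full code the pyramid resolutions work out as below (a pure-Python sketch; the names C2..C5 mirror the snippet above):

```python
import math

def same_conv_out(size, stride):
    """Output size of a 'same'-padded strided conv: ceil(size / stride)."""
    return math.ceil(size / stride)

sizes = {}
s = 224  # input height/width, matching the placeholder in the full code
for name in ["C2", "C3", "C4", "C5"]:
    s = same_conv_out(s, 2)
    sizes[name] = s
# C2=112, C3=56, C4=28, C5=14; GAP then collapses C5 to a length-channel scene vector (C6)
```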

Foreground-Scene Relation
This module is more involved, so let's walk through the paper's formulas in detail.
First, according to the formula:
[Figure: scene embedding formula]
we obtain u: C6 is passed through a 1×1 convolution to produce the scene embedding u. Then the two formulas below give each V_i, which is multiplied channel-wise with u to produce R_i:
[Figure: V_i projection formula]
[Figure: relation map R_i formula]
The last step follows the formula:
[Figure: re-weighted feature Z_i formula]
where R_i is passed through a sigmoid and multiplied with the result of applying a 1×1 convolution, BN, and ReLU to V_i.

    ######################### Foreground-Scene Relation
    # Scene embedding u: reshape the GAP vector to 1x1 spatial and project with a 1x1 conv
    C6=tf.reshape(C6,[-1,1,1,channel])
    U=tf.layers.conv2d(C6,channel,1,strides=1, padding='same')
    # Project each pyramid level to V_i with a 1x1 conv + ReLU
    V5=tf.nn.relu(tf.layers.conv2d(P5,channel,1,strides=1, padding='same'))
    V4=tf.nn.relu(tf.layers.conv2d(P4,channel,1,strides=1, padding='same'))
    V3=tf.nn.relu(tf.layers.conv2d(P3,channel,1,strides=1, padding='same'))
    V2=tf.nn.relu(tf.layers.conv2d(P2,channel,1,strides=1, padding='same'))
    # Relation maps R_i: channel-wise product of V_i with u, broadcast over H and W
    R5=V5*U
    R4=V4*U
    R3=V3*U
    R2=V2*U
    # Re-project V_i (1x1 conv + BN + ReLU) and gate it with sigmoid(R_i)
    Z2=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(V2,channel,1,strides=1, padding='same')))*tf.nn.sigmoid(R2)
    Z3=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(V3,channel,1,strides=1, padding='same')))*tf.nn.sigmoid(R3)
    Z4=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(V4,channel,1,strides=1, padding='same')))*tf.nn.sigmoid(R4)
    Z5=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(V5,channel,1,strides=1, padding='same')))*tf.nn.sigmoid(R5)
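The channel-wise products above rely on broadcasting: the 1×1×C scene embedding is stretched over the whole spatial grid. A minimal NumPy sketch of the gating, with shapes chosen purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
C = 8
V = rng.standard_normal((1, 56, 56, C))   # a projected pyramid feature V_i
u = rng.standard_normal((1, 1, 1, C))     # scene embedding after GAP + 1x1 conv

R = V * u                                  # broadcasts over H and W -> (1, 56, 56, C)
gate = 1.0 / (1.0 + np.exp(-R))            # sigmoid, values in (0, 1)
Z = V * gate                               # relation-gated feature, same shape as V
```

Pixels whose features align with the scene embedding get a gate near 1, so scene-consistent foreground responses are kept while background responses are suppressed.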

Finally, the Light-weight Decoder, shown below:
[Figure: light-weight decoder]
Each level gets a 3×3 convolution, BN, and ReLU, followed by one transposed convolution; the stride of the transposed convolution differs per level: 8, 4, and 2 for Z5, Z4, and Z3 respectively.

    # 3x3 conv + BN + ReLU, then one transposed conv per level (strides 2/4/8) back to Z2's resolution
    Z3=tf.layers.conv2d_transpose(tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(Z3,channel,3,strides=1, padding='same'))),channel,3,strides=2, padding='same')
    Z4=tf.layers.conv2d_transpose(tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(Z4,channel,3,strides=1, padding='same'))),channel,3,strides=4, padding='same')
    Z5=tf.layers.conv2d_transpose(tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(Z5,channel,3,strides=1, padding='same'))),channel,3,strides=8, padding='same')
    # Element-wise sum fuses the four levels
    Z=Z2+Z3+Z4+Z5
    print(Z)
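With 'same' padding, a transposed convolution multiplies the spatial size by its stride, which is why a single stride-2/4/8 deconv brings Z3/Z4/Z5 back to Z2's resolution. A small check, using the resolutions that follow from a 224×224 input:

```python
def deconv_same_out(size, stride):
    """'same'-padded transposed conv output size: size * stride."""
    return size * stride

# Pyramid resolutions for a 224x224 input: Z2=112, Z3=56, Z4=28, Z5=14
z2 = deconv_same_out(56, 2)   # Z3 -> 112, matches Z2
z4 = deconv_same_out(28, 4)   # Z4 -> 112
z8 = deconv_same_out(14, 8)   # Z5 -> 112
```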

A note on the implementation: we took a shortcut here and upsample each level back to the target resolution in a single step with the corresponding stride, rather than stacking repeated 2× upsampling stages as the paper does.
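To see what the shortcut saves, a rough parameter count compares one large-stride transposed conv with a stack of stride-2 ones (3×3 kernels, equal in/out channels, and bias-free layers assumed for simplicity):

```python
def deconv_params(channels, k=3):
    """Weights in one k x k transposed conv with equal in/out channels, no bias."""
    return channels * channels * k * k

C = 256
single_stride8 = deconv_params(C)        # one stride-8 deconv for 8x upsampling
stacked_stride2 = 3 * deconv_params(C)   # three stride-2 deconvs for the same 8x
# single = 589_824 weights vs stacked = 1_769_472
```

The single-step version is three times cheaper here, at the cost of a coarser upsampling than the paper's repeated conv-BN-ReLU + 2× stages.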

Full code

# -*- coding: utf-8 -*-
"""
Created on Tue Aug 18 11:34:42 2020

@author: surface
"""

import tensorflow as tf


x = tf.placeholder(tf.float32,[None, 224, 224, 3])  # input image size


def FarSeg(x,channel):
    #########################Multi-Branch Encoder
    C2=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(x,channel,3,strides=2, padding='same')))
    C3=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(C2,channel,3,strides=2, padding='same')))
    C4=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(C3,channel,3,strides=2, padding='same')))
    C5=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(C4,channel,3,strides=2, padding='same')))
    P5=tf.layers.conv2d(C5,channel,1,strides=1, padding='same')
    P4=tf.layers.conv2d(C4,channel,1,strides=1, padding='same')+tf.layers.conv2d_transpose(P5,channel,3,strides=2, padding='same')
    P3=tf.layers.conv2d(C3,channel,1,strides=1, padding='same')+tf.layers.conv2d_transpose(P4,channel,3,strides=2, padding='same')
    P2=tf.layers.conv2d(C2,channel,1,strides=1, padding='same')+tf.layers.conv2d_transpose(P3,channel,3,strides=2, padding='same') 
    C6=tf.keras.layers.GlobalAvgPool2D()(C5)
    #########################Foreground-Scene Relation   
    C6=tf.reshape(C6,[-1,1,1,channel])
    U=tf.layers.conv2d(C6,channel,1,strides=1, padding='same')
    V5=tf.nn.relu(tf.layers.conv2d(P5,channel,1,strides=1, padding='same'))
    V4=tf.nn.relu(tf.layers.conv2d(P4,channel,1,strides=1, padding='same'))
    V3=tf.nn.relu(tf.layers.conv2d(P3,channel,1,strides=1, padding='same'))
    V2=tf.nn.relu(tf.layers.conv2d(P2,channel,1,strides=1, padding='same'))
    R5=V5*U
    R4=V4*U
    R3=V3*U
    R2=V2*U
    Z2=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(V2,channel,1,strides=1, padding='same')))*tf.nn.sigmoid(R2)
    Z3=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(V3,channel,1,strides=1, padding='same')))*tf.nn.sigmoid(R3)
    Z4=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(V4,channel,1,strides=1, padding='same')))*tf.nn.sigmoid(R4)
    Z5=tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(V5,channel,1,strides=1, padding='same')))*tf.nn.sigmoid(R5)
    print(Z2,Z3,Z4,Z5)
    #########################Light-weight Decoder       
    Z3=tf.layers.conv2d_transpose(tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(Z3,channel,3,strides=1, padding='same'))),channel,3,strides=2, padding='same')
    Z4=tf.layers.conv2d_transpose(tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(Z4,channel,3,strides=1, padding='same'))),channel,3,strides=4, padding='same')
    Z5=tf.layers.conv2d_transpose(tf.nn.relu(tf.layers.batch_normalization(tf.layers.conv2d(Z5,channel,3,strides=1, padding='same'))),channel,3,strides=8, padding='same')
    Z=Z2+Z3+Z4+Z5
    print(Z)
    return Z

Z = FarSeg(x, 256)