Overview
Paper: Object-Contextual Representations for Semantic Segmentation
Reference code: PyTorch version
My take: the authors borrow the self-attention mechanism and adjust how the Q, K, and V inputs and outputs are computed, presenting the attention computation in a novel way. At its core, though, it is still a clever application of self-attention. For the paper's own detailed description, read the original paper, or click here for a reference.
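For reference, the plain scaled dot-product self-attention that OCR builds on can be sketched in a few lines of NumPy. The shapes and projection size below are made up purely for illustration, not taken from the paper:

```python
import numpy as np

def self_attention(x, wq, wk, wv):
    """Scaled dot-product self-attention over a set of N feature vectors.
    x: (N, c) input features; wq/wk/wv: (c, d) projection matrices."""
    q, k, v = x @ wq, x @ wk, x @ wv            # queries, keys, values: (N, d)
    scores = q @ k.T / np.sqrt(k.shape[-1])     # (N, N) pairwise similarities
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # row-wise softmax
    return weights @ v                          # (N, d) attended features

rng = np.random.default_rng(0)
x = rng.normal(size=(16, 32))                   # 16 "pixels", 32 channels
wq, wk, wv = (rng.normal(size=(32, 8)) for _ in range(3))
out = self_attention(x, wq, wk, wv)
print(out.shape)  # (16, 8)
```

OCR's twist is that the keys and values are not the pixels themselves but per-class region representations, so the attention map relates each pixel to `num_classes` regions instead of to every other pixel.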
Understanding the OCR module
The paper's explanation of this module is essentially just one model diagram.
At first glance it looks impressive. Look closer, though, and it only gives a series of block names without spelling out what operations each block actually contains. The only parts that explain themselves are the two losses, the three matrix multiplications, and the concat. (My habit with papers like this is to read the architecture diagram first; no flaming, please.) Even after going through the paper carefully, I only half understood the figure, until I found this expert's annotated diagram, at which point everything finally clicked. Link here.
With that diagram in hand, the code practically writes itself!
TF2 implementation
OCR module
import tensorflow.keras.backend as K
from tensorflow import keras
import tensorflow as tf
from edge_detect.Encoder_Edge import Encoder

def OCR_gather_head(PR, SOR):
    # PR: pixel representations (b, h, w, c)
    # SOR: soft object regions, i.e. the coarse segmentation logits (b, h, w, num_classes)
    PR = tf.reshape(PR, shape=[-1, PR.shape[1]*PR.shape[2], PR.shape[-1]])  # b hw c
    SOR = tf.reshape(SOR, shape=[-1, SOR.shape[1]*SOR.shape[2], SOR.shape[-1]])  # b hw num_classes
    SOR = tf.transpose(SOR, [0, 2, 1])  # b num_classes hw
    SOR = K.softmax(SOR)  # spatial softmax: each class map becomes a distribution over pixels
    # each object region representation is a softmax-weighted sum of pixel features
    object_region_representations = tf.matmul(SOR, PR)  # b num_classes c
    object_region_representations = tf.expand_dims(object_region_representations, axis=-2)  # b num_classes 1 c
    return object_region_representations
def OCR_DISTI_HEAD(PR, ORR):
    # query: transformed pixel representations
    query = keras.layers.Conv2D(filters=64, kernel_size=1, padding='same')(PR)
    query = keras.layers.BatchNormalization()(query)
    query = keras.layers.Activation('relu')(query)
    query = tf.reshape(query, shape=[-1, query.shape[1]*query.shape[2], query.shape[-1]])  # b hw c1

    # key: transformed object region representations
    key = keras.layers.Conv2D(filters=64, kernel_size=1, padding='same')(ORR)
    key = keras.layers.BatchNormalization()(key)
    key = keras.layers.Activation('relu')(key)
    key = tf.reshape(key, shape=[-1, key.shape[1]*key.shape[2], key.shape[-1]])  # b num_classes c1

    # NOTE: the original snippet breaks off at the key transform; the remainder
    # below is a reconstruction following the standard OCR formulation.
    value = keras.layers.Conv2D(filters=64, kernel_size=1, padding='same')(ORR)
    value = keras.layers.BatchNormalization()(value)
    value = keras.layers.Activation('relu')(value)
    value = tf.reshape(value, shape=[-1, value.shape[1]*value.shape[2], value.shape[-1]])  # b num_classes c1

    # pixel-region relation: scaled dot product, softmax over the class axis
    sim_map = tf.matmul(query, key, transpose_b=True) * (64 ** -0.5)  # b hw num_classes
    sim_map = K.softmax(sim_map)

    # object-contextual representation for every pixel
    context = tf.matmul(sim_map, value)  # b hw c1
    context = tf.reshape(context, shape=[-1, PR.shape[1], PR.shape[2], 64])  # b h w c1

    # augment the pixel features with their object context
    output = keras.layers.Concatenate(axis=-1)([context, PR])
    return output
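To sanity-check the region-gathering step, here is the same computation as OCR_gather_head done with NumPy in place of TensorFlow. All dimensions are made-up toy values for illustration:

```python
import numpy as np

b, h, w, c, num_classes = 2, 4, 4, 8, 3
PR = np.random.rand(b, h*w, c)             # flattened pixel representations: b x hw x c
SOR = np.random.rand(b, num_classes, h*w)  # transposed coarse class scores: b x K x hw

# spatial softmax: each class's score map becomes a distribution over pixels
SOR = np.exp(SOR) / np.exp(SOR).sum(axis=-1, keepdims=True)

# each region representation is a softmax-weighted average of pixel features
ORR = SOR @ PR
print(ORR.shape)  # (2, 3, 8): one c-dim vector per class per image
```

The batched matmul collapses the hw axis, which is exactly why the output has one feature vector per class rather than per pixel.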