Conditional DETR解读---带anchor的DETR

DETR存在的问题

1.收敛速度慢

2.对小目标物体检测效果不好,因为transformer计算量大,受限于计算规模,CNN提取特征时只采取了最后一层特征,没有用FPN等结构。所以对于小目标检测效果不好。

论文主要观点

  • 通过对DETRdecoder中的attentionmap进行可视化,发现query查询到的区域都是物体的extremity末端区域。所以论文认为attention尝试找到物体的边界区域。

  • 论文中认为DETRtransofmer结构中的信息主要可以分为两部分,一部分是与图像的特征(颜色纹理等)相关的信息,称为content,比如encoder或decoder的输出信息。另一部分是代表空间上的信息,称为spatial,比如position embedding等。

  • detr中的CNN与encoder只涉及图像特征向量提取;decoder中的self-attn只涉及query之间的交互去重;所以收敛慢的最可能原因发生在cross attn

  • Cross attention中的K包含encoder输出信息(content key Ck)与position embedding(spatial Key Pk),Q包含self attention的输出(content query Cq)和object query(spatial query Pq)信息。论文中发现去掉cross attention中的object基本不掉点,所以收敛慢很可能是content query难学习导致的。

  • 提出了reference point的概念,为每个query设定一个检测范围,使得匹配更加稳定,加快了收敛

  • 原始detr混合两者学习,使得content query难学习。所以将content与spatial进行解耦

在这里插入图片描述

变为

在这里插入图片描述

网络结构

在这里插入图片描述

对于object query生成了一个2D坐标embedding(上图中的s),用于限定当前query的预测范围。最终decoder的输出的是相对与s的偏移量

bbox回归输出

在这里插入图片描述

其中f是decoer的输出,S表示x,y的坐标。最终b是[x,y,w,h]的向量。

classifier分类输出

在这里插入图片描述

f是decoder的输出,输出每个候选框的类别

decoder Pq生成:

提出了reference point的概念,即图中的s,是一个2d的坐标(q_num,B,2),由object queries经过一个线性层生成,代表了每个query的预测范围。

s经过sigmoid和position embedding后(图中的Ps),跟FFN(decoder embedding)(即图中的T)做内积。得到空间特征Pq

在这里插入图片描述

在这里插入图片描述

代码spatial query这一部分的实现:

# query_pos [num_query,batch,d_model]
# reference_points_before_sigmoid [num_query,batch,2]  从query预测一个坐标,代表了这个query预测的大概范围
reference_points_before_sigmoid = self.ref_point_head(query_pos)    # [num_queries, batch_size, 2]
reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1)
for layer_id, layer in enumerate(self.layers):
    # 图里的s,代表了query的预测大概范围
    obj_center = reference_points[..., :2].transpose(0, 1)      # [num_queries, batch_size, 2]

    # For the first decoder layer, we do not apply transformation over p_s
    ## pos_transformation代表图里的T,表示decoder embedding的特征经过ffn后其实得到的是相对于s的偏移量
    if layer_id == 0:
        pos_transformation = 1
    else:
        pos_transformation = self.query_scale(output)

    # get sine embedding for the query vector
    query_sine_embed = gen_sineembed_for_position(obj_center)     
    # apply transformation
    # 最终的Pq,代表空间特征信息
    query_sine_embed = query_sine_embed * pos_transformation
    output = layer(output, memory, tgt_mask=tgt_mask,
                   memory_mask=memory_mask,
                   tgt_key_padding_mask=tgt_key_padding_mask,
                   memory_key_padding_mask=memory_key_padding_mask,
                   pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                   is_first=(layer_id == 0))

decoder中cross attention的实现


# ========== Begin of Cross-Attention =============
# Apply projections here
# shape: num_queries x batch_size x 256
q_content = self.ca_qcontent_proj(tgt)
k_content = self.ca_kcontent_proj(memory)
v = self.ca_v_proj(memory)

num_queries, bs, n_model = q_content.shape
hw, _, _ = k_content.shape

# k的位置编码
k_pos = self.ca_kpos_proj(pos)

# For the first decoder layer, we concatenate the positional embedding predicted from 
# the object query (the positional embedding) into the original query (key) in DETR.
if is_first:
    q_pos = self.ca_qpos_proj(query_pos)
    q = q_content + q_pos
    k = k_content + k_pos
else:
    q = q_content
    k = k_content

q = q.view(num_queries, bs, self.nhead, n_model//self.nhead)
query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model//self.nhead)
# decoder embedding cat spatial query
q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
k = k.view(hw, bs, self.nhead, n_model//self.nhead)
# encoder embdeding cat position embedding
k_pos = k_pos.view(hw, bs, self.nhead, n_model//self.nhead)
k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)

tgt2 = self.cross_attn(query=q,
                           key=k,
                           value=v, attn_mask=memory_mask,
                           key_padding_mask=memory_key_padding_mask)[0]               
# ========== End of Cross-Attention =============

head的实现

# hs代表decoder embedding,reference代表s(reference point)
hs, reference = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])
reference_before_sigmoid = inverse_sigmoid(reference)
outputs_coords = []
for lvl in range(hs.shape[0]):
    # 回归head hs输出相对于 reference的偏移量,得到检测框
    tmp = self.bbox_embed(hs[lvl])
    tmp[..., :2] += reference_before_sigmoid
    outputs_coord = tmp.sigmoid()
    outputs_coords.append(outputs_coord)
outputs_coord = torch.stack(outputs_coords)
#分类head,hs输出分类结果
outputs_class = self.class_embed(hs)

总结思考

实际上conditional DETR有点像transfoermer版本的faster-RCNN。将特征信息与空间信息进行了解耦。reference point像anchor的概念,让网络自己为每个query设定一个anchor范围,从而使得二分匹配更加问题,所以加快了网络的收敛

作者论文解读:https://zhuanlan.zhihu.com/p/401916664
公式解释得更加详细

  • 19
    点赞
  • 30
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值