DeformableAttention: Principle and Source-Code Implementation

This column focuses on source-code implementations for deep learning and autonomous driving.

Principle

DeformableAttention is the transformer variant that virtually every popular 3D-to-2D BEV scheme builds on.
Traditional transformer attention attends to global features, which is slow. A DeformableAttention module instead attends only to a small set of key sampling points around each target. The original DETR needs many epochs before its queries lock onto the right features; Deformable DETR converges much faster, reportedly at about 1/10 of the training cost.
The principle below is illustrated with DETR3D's approach as the example.

Step 1: the inputs

Define a query of shape (900, 256): 900 targets, each with a 256-dim query feature.
Define a query_pos with the same shape as query.
Define reference_points of shape (900, 3) as the targets' reference points.
The input is pts_feats (1, 43054, 256), the flattened multi-scale features.
spatial_shapes records each level's feature-map size: [[180, 180], [90, 90], [45, 45], [23, 23]]
level_start_index records each level's starting offset within pts_feats: [0, 32400, 40500, 42525]
You can verify these numbers yourself: 180*180 = 32400, +90*90 = 40500, +45*45 = 42525, +23*23 = 43054, the total length of pts_feats.
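A quick standalone check of these numbers (a minimal sketch; level_start_index is simply the exclusive prefix sum of the per-level lengths):

import torch

spatial_shapes = torch.tensor([[180, 180], [90, 90], [45, 45], [23, 23]])
level_len = spatial_shapes[:, 0] * spatial_shapes[:, 1]  # [32400, 8100, 2025, 529]
print(level_len.sum())  # 43054, the length of pts_feats
# exclusive prefix sum of the per-level lengths
level_start_index = torch.cat([level_len.new_zeros(1), level_len.cumsum(0)[:-1]])
print(level_start_index)  # [0, 32400, 40500, 42525]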

Step 2: preparation

Reshape pts_feats to (1, 43054, 8, 32), splitting the 256 channels across the 8 heads:

value = value.view(bs, num_value, self.num_heads, -1)

Generate the sampling-point offsets

query is passed through the self.sampling_offsets linear layer and reshaped, giving:
sampling_offsets: torch.Size([1, 900, 8, 4, 4, 2])
where 8 is the number of heads, the first 4 is the number of feature levels, the second 4 is the number of sampling points, and 2 is the x/y dimension of each point. In other words, each of the 8 heads samples 4 points on each of the 4 feature levels; these are the x/y offsets of those 8*4*4 points.
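A minimal standalone sketch of this projection (the nn.Linear here is a stand-in for self.sampling_offsets; the full demo at the end of this post does the same thing):

import torch
import torch.nn as nn

num_head, level_num, sample_num = 8, 4, 4
sampling_offsets_net = nn.Linear(256, num_head * level_num * sample_num * 2)
query = torch.rand(1, 900, 256)
sampling_offsets = sampling_offsets_net(query).view(1, 900, num_head, level_num, sample_num, 2)
print(sampling_offsets.shape)  # torch.Size([1, 900, 8, 4, 4, 2])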

Generate the sampling-point weights

query is passed through the self.attention_weights linear layer and reshaped, giving:
attention_weights: torch.Size([1, 900, 8, 4, 4])
one weight per sampling point above.
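The matching sketch for the weights, with a softmax so that each head's 4*4 = 16 point weights sum to 1 (the nn.Linear is again a stand-in for self.attention_weights):

import torch
import torch.nn as nn

num_head, level_num, sample_num = 8, 4, 4
attention_weights_net = nn.Linear(256, num_head * level_num * sample_num)
query = torch.rand(1, 900, 256)
attention_weights = attention_weights_net(query).view(1, 900, num_head, level_num * sample_num)
# normalize over all 16 points of a head
attention_weights = attention_weights.softmax(dim=-1).view(1, 900, num_head, level_num, sample_num)
print(attention_weights.shape)  # torch.Size([1, 900, 8, 4, 4])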

Generate the sampling locations

Adding the offsets to reference_points yields the actual sampling locations:

sampling_location = reference_points[:, :, None, None, None, :2] + sampling_offsets

sampling_locations: torch.Size([1, 900, 8, 4, 4, 2])

Put plainly: we define a query embedding that, on its own, generates both the locations it is about to sample and the weights of those sampled points.

Step 3: the actual sampling

Inputs:
value: torch.Size([b, 43054, 8, 32])
sampling_locations: torch.Size([b, 900, 8, 4, 4, 2])
attention_weights: torch.Size([b, 900, 8, 4, 4])
spatial_shapes: [[180, 180], [90, 90], [45, 45], [23, 23]]

value is split into levels according to spatial_shapes:
[torch.Size([b, 180*180, 8, 32]), torch.Size([b, 90*90, 8, 32]), torch.Size([b, 45*45, 8, 32]), torch.Size([b, 23*23, 8, 32])]
and each level is reshaped back into image layout, e.g. torch.Size([b*8, 32, 180, 180]).
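A minimal sketch of this split-and-reshape for level 0 (shapes as in the walkthrough):

import torch

b, num_head, dim_per_head = 1, 8, 32
spatial_shapes = torch.tensor([[180, 180], [90, 90], [45, 45], [23, 23]])
value = torch.rand(b, 43054, num_head, dim_per_head)
split_sizes = [int(h * w) for h, w in spatial_shapes]
value_list = value.split(split_sizes, dim=1)
# level 0: (b, 180*180, 8, 32) -> (b*8, 32, 180, 180)
value_l0 = value_list[0].permute(0, 2, 3, 1).reshape(b * num_head, dim_per_head, 180, 180)
print(value_l0.shape)  # torch.Size([8, 32, 180, 180])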

sampling_locations originally holds the sampling positions in the [0, 1) range; to match the convention of F.grid_sample they are rescaled to [-1, 1).
F.grid_sample is then called on each feature level, with input value of shape torch.Size([b*8, 32, level_h, level_w]) and sampling grid sampling_grid of shape torch.Size([b*8, 900, 4, 2]).
The output is sampling_value: torch.Size([b*8, 32, 900, 4]).
That is, each of the 900 queries samples 4 points from the (32, level_h, level_w) feature map, and each query gets back 4 pixel features with 32 channels.
The sampling_value results from the 4 levels are then stacked together: torch.Size([b*8, 32, 900, 4*4]).

attention_weights is reshaped to the matching form (torch.Size([b*8, 1, 900, 4*4])), and the 16 sampled features are weighted and summed, producing an output of torch.Size([b, 32*8, 900]). An FFN then fuses the multi-head features with fully connected layers.
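A single-level sketch of the grid_sample call and the weighted sum (shapes match the walkthrough; the function in the next section does this across all four levels):

import torch
import torch.nn.functional as F

bh, dim, q, p = 1 * 8, 32, 900, 4        # batch*num_head, channels per head, queries, points
value_l = torch.rand(bh, dim, 180, 180)  # one feature level laid out as an image
grid = torch.rand(bh, q, p, 2) * 2 - 1   # sampling points rescaled from [0, 1) to [-1, 1)
sampling_value = F.grid_sample(value_l, grid, mode='bilinear',
                               padding_mode='zeros', align_corners=False)
print(sampling_value.shape)              # torch.Size([8, 32, 900, 4])
# one weight per sampled point, broadcast over the 32 channels
weights = torch.rand(bh, 1, q, p)
out = (sampling_value * weights).sum(-1)  # torch.Size([8, 32, 900])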

Source code

import torch
import torch.nn.functional as F
import torch.nn as nn


def multi_scale_deformable_attn_pytorch(value, spatial_shapes, sampling_locations, attention_weights):
    batch, _, num_head, embedding_dim_perhead = value.shape
    _, query_size, _, level_num, sample_num, _ = sampling_locations.shape
    split_list = []
    for h, w in spatial_shapes:
        split_list.append(int(h * w))
    value_list = value.split(split_size=tuple(split_list), dim=1)
    # rescale the [0, 1) sampling locations to the [-1, 1) range expected by F.grid_sample
    sampling_grid = 2 * sampling_locations - 1
    output_list = []
    for level_id, (h, w) in enumerate(spatial_shapes):
        h = int(h)
        w = int(w)
        # (batch, value_len, num_head, embedding_dim_perhead)
        # -> (batch, num_head, embedding_dim_perhead, value_len)
        # -> (batch*num_head, embedding_dim_perhead, h, w)
        value_l = value_list[level_id].permute(0, 2, 3, 1).reshape(batch * num_head, embedding_dim_perhead, h, w)
        # (batch, query_size, num_head, level_num, sample_num, 2)
        # -> (batch, query_size, num_head, sample_num, 2)
        # -> (batch, num_head, query_size, sample_num, 2)
        # -> (batch*num_head, query_size, sample_num, 2)
        sampling_grid_l = sampling_grid[:, :, :, level_id, :, :].permute(0, 2, 1, 3, 4).reshape(batch * num_head,
                                                                                                query_size, sample_num, 2)
        # (batch*num_head, embedding_dim_perhead, query_size, sample_num)
        output = F.grid_sample(input=value_l,
                               grid=sampling_grid_l,
                               mode='bilinear',
                               padding_mode='zeros',
                               align_corners=False)
        output_list.append(output)
    # (batch*num_head, embedding_dim_perhead, query_size, level_num, sample_num)
    outputs = torch.stack(output_list, dim=-2)
    # (batch, query_size, num_head, level_num, sample_num)
    # -> (batch, num_head, query_size, level_num, sample_num)
    # -> (batch*num_head, 1, query_size, level_num, sample_num)
    attention_weights = attention_weights.permute(0, 2, 1, 3, 4).reshape(batch * num_head, 1, query_size, level_num,
                                                                         sample_num)
    outputs = outputs * attention_weights
    # (batch*num_head, embedding_dim_perhead, query_size)
    # -> (batch, num_head, embedding_dim_perhead, query_size)
    # -> (batch, query_size, num_head, embedding_dim_perhead)
    # -> (batch, query_size, num_head*embedding_dim_perhead)
    outputs = outputs.sum(-1).sum(-1).view(batch, num_head, embedding_dim_perhead, query_size) \
        .permute(0, 3, 1, 2).reshape(batch, query_size, num_head * embedding_dim_perhead)
    return outputs.contiguous()


if __name__ == '__main__':
    batch = 1
    num_head = 8
    embedding_dim = 256
    query_size = 900
    spatial_shapes = torch.tensor([[180, 180], [90, 90], [45, 45], [23, 23]])
    value_len = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
    value = torch.rand(size=(batch, value_len, embedding_dim))
    query_embedding = torch.rand(size=(batch, query_size, embedding_dim * 2 + 3))
    query = query_embedding[..., :embedding_dim]
    query_pos = query_embedding[..., embedding_dim:2 * embedding_dim]
    reference_points = query_embedding[..., 2 * embedding_dim:]
    # Discussion 1: in deformable attention the query does not interact with value
    # to produce the attention weights; the weights depend only on the query, so at
    # inference time the attention weights (and sampling_locations) are fixed for a
    # given query. According to the authors, weights computed from query-key
    # interaction suffer from a degeneration problem: the resulting attention
    # weights do not actually vary with the keys. Both ways of computing the
    # weights therefore give comparable results, and projecting the query directly
    # is faster and cheaper, which is why the authors chose it.
    # Discussion 2: with a fixed query, the attention weights of the first layer
    # cannot change, but the second layer's query depends on value, so its weights
    # do vary. So the self-attention in the first layer is not necessary.
    level_num = 4
    sample_num = 4
    sampling_offsets_net = nn.Linear(in_features=embedding_dim, out_features=num_head * level_num * sample_num * 2)
    sampling_offsets = sampling_offsets_net(query).view(batch, query_size, num_head, level_num, sample_num, 2)
    sampling_location = reference_points[:, :, None, None, None, :2] + sampling_offsets
    attention_weights_net = nn.Linear(in_features=embedding_dim, out_features=num_head * level_num * sample_num)
    attention_weights = attention_weights_net(query).view(batch, query_size, num_head, level_num * sample_num)
    attention_weights = attention_weights.softmax(dim=-1).view(batch, query_size, num_head, level_num,
                                                               sample_num)  # the 16 point weights sum to 1
    embedding_dim_perhead = embedding_dim // num_head
    value = value.view(batch, value_len, num_head, -1)

    output = multi_scale_deformable_attn_pytorch(
        value, spatial_shapes, sampling_location, attention_weights)
    print(output.shape)  # torch.Size([1, 900, 256])

