代码地址:
https://github.com/Pointcept/Pointcept/blob/main/pointcept/models/point_transformer_v2/point_transformer_v2m1_origin.py
class GroupedVectorAttention(nn.Module):
    """Grouped vector attention (Point Transformer V2).

    Channels are split into ``groups`` equal-sized chunks; every channel in a
    group shares one scalar attention weight, so the weight-encoding MLP maps
    relations to ``groups`` values instead of ``embed_channels`` values.

    Args:
        embed_channels: feature dimension; must be divisible by ``groups``.
        groups: number of channel groups sharing an attention weight.
        attn_drop_rate: dropout rate applied to the attention weights.
        qkv_bias: whether the q/k/v linear layers use a bias term.
        pe_multiplier: if True, modulate the q-k relation multiplicatively
            with a learned positional encoding.
        pe_bias: if True, add a learned positional encoding to the q-k
            relation and to the grouped values.
    """

    def __init__(
        self,
        embed_channels,
        groups,
        attn_drop_rate=0.0,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
    ):
        super(GroupedVectorAttention, self).__init__()
        self.embed_channels = embed_channels
        self.groups = groups
        # Each group must cover an integer number of channels.
        assert embed_channels % groups == 0
        self.attn_drop_rate = attn_drop_rate
        self.qkv_bias = qkv_bias
        self.pe_multiplier = pe_multiplier
        self.pe_bias = pe_bias
        # Query/key projections include a norm + ReLU; value is a plain linear.
        self.linear_q = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_k = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_v = nn.Linear(embed_channels, embed_channels, bias=qkv_bias)
        if self.pe_multiplier:
            # MLP lifting relative xyz (3-d) to a per-channel multiplier.
            self.linear_p_multiplier = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        if self.pe_bias:
            # MLP lifting relative xyz (3-d) to a per-channel additive bias.
            self.linear_p_bias = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        # Maps the (embed_channels)-dim relation down to one weight per group.
        self.weight_encoding = nn.Sequential(
            GroupedLinear(embed_channels, groups, groups),
            PointBatchNorm(groups),
            nn.ReLU(inplace=True),
            nn.Linear(groups, groups),
        )
        # Softmax over dim=1, i.e. over the neighbor (sample) dimension.
        self.softmax = nn.Softmax(dim=1)
        self.attn_drop = nn.Dropout(attn_drop_rate)

    def forward(self, feat, coord, reference_index):
        """Apply grouped vector attention over each point's neighborhood.

        Args:
            feat: per-point features, shape (n, embed_channels) — assumed
                from the linear layers; TODO confirm against caller.
            coord: per-point xyz coordinates (n, 3).
            reference_index: neighbor indices per point, presumably
                (n, nsample) with -1 padding for missing neighbors (the
                sign-based mask below relies on this) — verify in pointops.

        Returns:
            Aggregated features of shape (n, embed_channels).
        """
        query, key, value = (
            self.linear_q(feat),
            self.linear_k(feat),
            self.linear_v(feat),
        )
        # Gather neighbor keys; with_xyz=True prepends relative xyz (3 dims).
        key = pointops.grouping(reference_index, key, coord, with_xyz=True)
        value = pointops.grouping(reference_index, value, coord, with_xyz=False)
        # Split off the relative position channels from the key features.
        pos, key = key[:, :, 0:3], key[:, :, 3:]
        # Vector relation between each neighbor key and the center query.
        relation_qk = key - query.unsqueeze(1)
        if self.pe_multiplier:
            pem = self.linear_p_multiplier(pos)
            relation_qk = relation_qk * pem
        if self.pe_bias:
            peb = self.linear_p_bias(pos)
            relation_qk = relation_qk + peb
            # NOTE: the positional bias is also added to the values.
            value = value + peb
        weight = self.weight_encoding(relation_qk)
        weight = self.attn_drop(self.softmax(weight))
        # sign(idx + 1) is 0 for padded neighbors (idx == -1), 1 otherwise;
        # zeroes out their (already softmaxed) weights.
        mask = torch.sign(reference_index + 1)
        weight = torch.einsum("n s g, n s -> n s g", weight, mask)
        # Reshape values into channel groups, weight each group, then merge.
        value = einops.rearrange(value, "n ns (g i) -> n ns g i", g=self.groups)
        feat = torch.einsum("n s g i, n s g -> n g i", value, weight)
        feat = einops.rearrange(feat, "n g i -> n (g i)")
        return feat
- 分组的向量注意力
self.weight_encoding = nn.Sequential(
GroupedLinear(embed_channels, groups, groups),
PointBatchNorm(groups),
nn.ReLU(inplace=True),
nn.Linear(groups, groups),
)
对channel进行分组:embed_channels 被均分成 groups 组,每组内的所有 channel 共享同一个标量 attention 权重,因此 weight_encoding 只需输出 groups 维而不是 embed_channels 维,降低了计算量。
- 位置编码加入relation的计算
两个选项分别是self.pe_multiplier和self.pe_bias,以不同方式把位置编码融入relation的计算:pe_multiplier 把位置编码逐元素乘到 relation 上,pe_bias 则把位置编码加到 relation 上;注意 pe_bias 同时还会把同一个位置编码加到 value 上(`value = value + peb`)。