def __init__(self):
    """Assemble the model's three sub-modules: backbone, CAR head, attention.

    The backbone type and its keyword arguments are read from the global
    ``cfg`` object; the head and the attention module are both built with
    the same fixed channel width.
    """
    super(ModelBuilder, self).__init__()

    # Channel width handed to both the CAR head and the attention module.
    feature_channels = 256

    # build backbone
    backbone_type = cfg.BACKBONE.TYPE
    backbone_kwargs = cfg.BACKBONE.KWARGS
    self.backbone = get_backbone(backbone_type, **backbone_kwargs)

    # build car head
    self.car_head = CARHead(cfg, feature_channels)

    # build response map
    self.attention = Graph_Attention_Union(feature_channels,
                                           feature_channels)
在 CARHead() 中,由分类子网和回归子网完成包围盒预测。
Graph_Attention_Union() 的实现如下:
class Graph_Attention_Union(nn.Module):
    """Attention module that passes information from ``zf`` into ``xf``.

    Every spatial position of ``xf`` attends over all spatial positions of
    ``zf``; the attention-weighted message is concatenated with the
    transformed ``xf`` features and fused by a 1x1 conv + BN + ReLU.

    Args:
        in_channel: number of channels of both input feature maps.
        out_channel: number of channels of the fused output.
    """

    @staticmethod
    def _conv_bn_relu(channels_in, channels_out):
        """Return a 1x1 convolution followed by batch norm and in-place ReLU."""
        return nn.Sequential(
            nn.Conv2d(channels_in, channels_out, kernel_size=1, stride=1),
            nn.BatchNorm2d(channels_out),
            nn.ReLU(inplace=True),
        )

    def __init__(self, in_channel, out_channel):
        super().__init__()

        # 1x1 linear transformation applied to ``zf`` in forward().
        self.support = nn.Conv2d(in_channel, in_channel,
                                 kernel_size=1, stride=1)
        # 1x1 linear transformation applied to ``xf`` in forward().
        self.query = nn.Conv2d(in_channel, in_channel,
                               kernel_size=1, stride=1)

        # Shared linear transformation used for message passing
        # (applied to both ``xf`` and ``zf``).
        self.g = self._conv_bn_relu(in_channel, in_channel)

        # Fuses the concatenated [message, node feature] pair, hence the
        # doubled input width.
        self.fi = self._conv_bn_relu(in_channel * 2, out_channel)

    def forward(self, zf, xf):
        """Aggregate ``zf`` information into every position of ``xf``.

        Args:
            zf: 4-D feature map (N, C, Hz, Wz); the accompanying notes
                treat it as the target-template graph Gt.
            xf: 4-D feature map (N, C, Hx, Wx); the accompanying notes
                treat it as the search-region graph Gs.

        Returns:
            Tensor of shape (N, out_channel, Hx, Wx).
        """
        # Linear transformations used to compute the similarity.
        search_proj = self.query(xf)
        template_proj = self.support(zf)

        # Linear transformations used for message passing.
        search_msg = self.g(xf)
        template_msg = self.g(zf)

        chan_x = search_proj.size(1)
        height_x = search_proj.size(2)
        width_x = search_proj.size(3)
        chan_z = template_proj.size(1)
        nodes_z = template_proj.size(2) * template_proj.size(3)

        # Flatten the spatial grid so that every position becomes one node.
        template_keys = template_proj.view(-1, chan_z, nodes_z)
        template_values = template_msg.view(-1, chan_z, nodes_z)
        template_values = template_values.transpose(1, 2)
        search_queries = search_proj.view(-1, chan_x, height_x * width_x)
        search_queries = search_queries.transpose(1, 2)

        # Pairwise similarity between every xf node and every zf node.
        # This corresponds to the attention computation in the paper; the
        # softmax normalization over the zf nodes balances the amount of
        # information sent to the search region.
        attention = search_queries @ template_keys
        attention = torch.softmax(attention, dim=2)

        # Corresponds to the aggregation step in the paper: the attention
        # passed from all nodes in Gt to the i-th node in Gs.
        message = (attention @ template_values).transpose(1, 2)
        message = message.view(-1, chan_x, height_x, width_x)

        # Aggregated feature: concatenate along channels, then fuse.
        fused_input = torch.cat((message, search_msg), dim=1)
        return self.fi(fused_input)
以上代码对应论文中的特征融合步骤:
将聚合特征与经线性变换后的节点特征 hi 在通道维上拼接,再经过 1×1 卷积、BN 和 ReLU 进行融合,以获得由目标信息增强的、更强大的特征表示。