query_embed是如何赋值或初始化的:
1.首先看搭建的整个网络(detr.py):
def build(args):
# 搭建backbone resnet + PositionEmbeddingSine
backbone = build_backbone(args)
# 搭建transformer
transformer = build_transformer(args)
# 搭建整个DETR模型
model = DETR(
backbone,
transformer,
num_classes=num_classes,
num_queries=args.num_queries,
aux_loss=args.aux_loss,
)
2.在DETR结构上query_embed首次定义。(detr.py)
创建为一个可学习的参数。query_embed是一个100x256 的可学习参数。query_embed的作用是为解码器提供初始的查询信息。通过学习这些查询向量,模型可以在解码过程中与编码器的输出进行交互,以生成最终的预测结果。
class DETR(nn.Module):
""" This is the DETR module that performs object detection """
def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
super().__init__()
self.num_queries = num_queries
self.transformer = transformer
# self.query_embed 类似于传统目标检测里面的anchor 这里设置了100个 [100,256]
self.query_embed = nn.Embedding(num_queries, hidden_dim)
3.看下transformer的搭建,(transformer.py)
将位置编码和查询(query_embed)嵌入调整为与输入数据(src
)相匹配的形状,并生成一个tgt
并且还有个要注意的是
query_pos=query_embed,对象查询他给他重新命名了。
def build_transformer(args):
return Transformer(
...
)
class Transformer(nn.Module):
def forward(self, src, mask, query_embed, pos_embed):
# flatten NxCxHxW to HWxNxC
bs, c, h, w = src.shape
src = src.flatten(2).permute(2, 0, 1)
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
mask = mask.flatten(1)
tgt = torch.zeros_like(query_embed)
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
pos=pos_embed, query_pos=query_embed)
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
4.看下在decoder的定义:
class TransformerDecoder(nn.Module):
def __init__(self, decoder_layer, num_layers):
super().__init__()
self.layers = _get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
def forward(self, tgt, memory,
# 目标序列的掩码,用于指定解码器在生成目标序列时应该注意哪些位置。确保在生成每个目标时,解码器不会访问后面尚未生成的目标。
tgt_mask: Optional[Tensor] = None,
# 用于指定在计算注意力时应该忽略哪些位置。
memory_mask: Optional[Tensor] = None,
# 目标序列的填充掩码,用于指定目标序列中哪些位置是填充值。这些位置的注意力权重将被设置为负无穷大,以排除填充值对注意力计算的影响。
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
output = tgt
for layer in self.layers:
output = layer(output, memory, tgt_mask=tgt_mask,
memory_mask=memory_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
pos=pos, query_pos=query_pos)
return output.unsqueeze(0)
class TransformerDecoderLayer(nn.Module):
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
activation="relu", normalize_before=False):
super().__init__()
self.normalize_before = normalize_before
...
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
return tensor if pos is None else tensor + pos
def forward(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
if self.normalize_before:
return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
return self.forward_post(tgt, memory, tgt_mask, memory_mask,
tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
5.随机取一个forward_pre分析:执行解码器层的前向传播。
在这里v直接使用了tgt2,q和k分别是加上了query_pos的结果,
结合结构图一起看,在第一层的时候结构图是写的V是来自tgt2,是全为0的一个向量,Q和K的结果不一样,是加上了query_pos的结果,而query_pos的值是随机初始化的,和结构图不太一样,所以这个结构图不太准确。即v是全0,qk不是全0。
def forward_pre(self, tgt, memory,
tgt_mask: Optional[Tensor] = None,
memory_mask: Optional[Tensor] = None,
tgt_key_padding_mask: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
pos: Optional[Tensor] = None,
query_pos: Optional[Tensor] = None):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, query_pos)
# v取的0值,q和k取的是v加上原query_pos的值
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask)[0]
# 将自注意力机制的结果与原始输入进行相加,并施加随机失活操作,以增强模型的表示能力和泛化性能。
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
# 在此步骤时,q和k就不相同了,q是tgt2和query_pos结合的结果,而k变成了encoder输出的结果和位置编码pos的结合结果,v变成了encoder的输出memory,
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
key=self.with_pos_embed(memory, pos),
value=memory, attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask)[0]
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
# 和网络结构图中大致一样,这把norm提前了,先算的norm,再计算的ffn,最后算的add。
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt