源代码
def forward(self, f_e, pos_emb, bboxes):
bs, _, h, w = f_e.size()
# extract the shape features or objectness
if not self.zero_shot:
box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0]
box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1]
shape_or_objectness = self.shape_or_objectness(box_hw).reshape(
bs, -1, self.kernel_dim ** 2, self.emb_dim
).flatten(1, 2).transpose(0, 1)
else:
shape_or_objectness = self.shape_or_objectness.expand(
bs, -1, -1, -1
).flatten(1, 2).transpose(0, 1)
# if not zero shot add appearance
if not self.zero_shot:
# reshape bboxes into the format suitable for roi_align
bboxes = torch.cat([
torch.arange(
bs, requires_grad=False
).to(bboxes.device).repeat_interleave(self.num_objects).reshape(-1, 1),
bboxes.flatten(0, 1),
], dim=1)
appearance = roi_align(
f_e,
boxes=bboxes, output_size=self.kernel_dim,
spatial_scale=1.0 / self.reduction, aligned=True
).permute(0, 2, 3, 1).reshape(
bs, self.num_objects * self.kernel_dim ** 2, -1
).transpose(0, 1)
else:
appearance = None
query_pos_emb = self.pos_emb(
bs, self.kernel_dim, self.kernel_dim, f_e.device
).flatten(2).permute(2, 0, 1).repeat(self.num_objects, 1, 1)
if self.num_iterative_steps > 0:
memory = f_e.flatten(2).permute(2, 0, 1)
all_prototypes = self.iterative_adaptation(
shape_or_objectness, appearance, memory, pos_emb, query_pos_emb
)
else:
if shape_or_objectness is not None and appearance is not None:
all_prototypes = (shape_or_objectness + appearance).unsqueeze(0)
else:
all_prototypes = (
shape_or_objectness if shape_or_objectness is not None else appearance
).unsqueeze(0)
return all_prototypes
注释代码
def forward(self, f_e, pos_emb, bboxes):
'''
f_e(编码后的图像特征)
pos_emb(位置嵌入)
bboxes(边界框)
'''
# 获取图像特征的尺寸信息
bs, _, h, w = f_e.size()
# extract the shape features or objectness
# 提取形状特征或对象显著性(objectness)
if not self.zero_shot:
# 非零样本情况下,计算边界框的宽度和高度
box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0] # 宽度
box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1] # 高度
# 将形状信息通过全连接网络转换为特征表示
shape_or_objectness = self.shape_or_objectness(box_hw).reshape(
bs, -1, self.kernel_dim ** 2, self.emb_dim
).flatten(1, 2).transpose(0, 1)
else:
shape_or_objectness = self.shape_or_objectness.expand(
bs, -1, -1, -1
).flatten(1, 2).transpose(0, 1)
# if not zero shot add appearance
# 如果不是零样本学习场景,则添加外观特征
# 当处于非零样本学习场景时,代码通过roi_align操作提取边界框内的特征,这些特征代表了对象的外观信息。
# roi_align操作从编码后的图像特征f_e中,根据提供的边界框bboxes提取特征,生成与对象形状相关的特征图。
# 通过permute和reshape操作调整提取的特征的形状,以便于与形状特征或其他处理步骤融合。
# 如果处于零样本学习场景,则不进行外观特征的提取,appearance被设置为None,这可能是因为在零样本场景下没有足够的样本来指导外观特征的提取。
if not self.zero_shot:
# reshape bboxes into the format suitable for roi_align
# 将边界框bboxes重塑为适用于roi_align的格式
# torch.arange生成从0到bs-1的整数序列,表示每个样本的索引
# requires_grad=False表示这些索引不需要计算梯度
# to(bboxes.device)将索引移动到bboxes所在的设备(GPU或CPU)
# repeat_interleave(self.num_objects)将每个索引重复num_objects次,以匹配样本数量
# reshape(-1, 1)将重复后的索引重塑为(-1, 1)的形状
# torch.cat沿着指定的维度(这里是dim=1)连接张量
bboxes = torch.cat([
torch.arange(
bs, requires_grad=False
).to(bboxes.device).repeat_interleave(self.num_objects).reshape(-1, 1),
bboxes.flatten(0, 1),
], dim=1)
# 使用roi_align从特征图f_e中提取与边界框对应的特征
# roi_align是一种池化操作,用于从特征图中提取感兴趣区域(bounding box)的特征
# boxes=bboxes传入包含边界框的张量
# output_size=self.kernel_dim指定输出特征图的大小
# spatial_scale=1.0 / self.reduction用于控制池化的比例,与reduction参数成反比
# aligned=True表示使用对齐的ROI池化,可以更好地处理边界框的边界
# 调整提取的外观特征的形状以适应后续操作
# permute(0, 2, 3, 1)重新排列张量的维度,将特征图的维度移到最前面
# reshape(bs, self.num_objects * self.kernel_dim ** 2, -1)将特征图展平为二维
# transpose(0, 1)交换第一个和第二个维度,以匹配期望的输入格式
appearance = roi_align(
f_e,
boxes=bboxes, output_size=self.kernel_dim,
spatial_scale=1.0 / self.reduction, aligned=True
).permute(0, 2, 3, 1).reshape(
bs, self.num_objects * self.kernel_dim ** 2, -1
).transpose(0, 1)
else:
# 如果是零样本学习场景,不提取外观特征,appearance设置为None
appearance = None
# 负责生成查询位置嵌入(query positional embedding)并根据迭代适应模块处理输入特征
# 生成查询位置嵌入
# self.pos_emb是一个用于生成位置嵌入的模块,它接收批量大小bs、核尺寸kernel_dim、核尺寸kernel_dim和设备f_e.device作为参数
# .flatten(2)将除了最后一个维度外的所有维度展平
# .permute(2, 0, 1)重新排列维度,将位置嵌入调整为正确的形状以用于后续操作
# .repeat复制num_objects次,以匹配样本数量
query_pos_emb = self.pos_emb(
bs, self.kernel_dim, self.kernel_dim, f_e.device
).flatten(2).permute(2, 0, 1).repeat(self.num_objects, 1, 1)
# 如果迭代适应模块的迭代步数大于0,则调用该模块
if self.num_iterative_steps > 0:
# 将编码后的图像特征f_e展平并重新排列维度,以匹配迭代适应模块的输入要求
memory = f_e.flatten(2).permute(2, 0, 1)
# 调用迭代适应模块,传入形状或对象显著性特征、外观特征、内存特征、位置嵌入和查询位置嵌入
# 该模块将执行一系列迭代步骤来适应和改进特征表示
all_prototypes = self.iterative_adaptation(
shape_or_objectness, appearance, memory, pos_emb, query_pos_emb
)
# 如果迭代适应模块的迭代步数为0,则执行以下操作
# 根据形状或对象显著性特征(shape_or_objectness)和外观特征(appearance)生成对象原型(all_prototypes)
else:
# 检查形状或对象显著性特征和外观特征是否都不为None
# 如果两者都存在,将它们相加并扩展维度以形成对象原型
if shape_or_objectness is not None and appearance is not None:
# 将形状或对象显著性特征和外观特征相加,得到综合的特征表示
# .unsqueeze(0)在第一个维度(批次维度)上扩展张量,从(N, C)变为(1, N, C)
all_prototypes = (shape_or_objectness + appearance).unsqueeze(0)
# 如果其中之一为None(可能在零样本学习场景中),选择非None的特征
# 并扩展维度形成对象原型
else:
# 选择shape_or_objectness或appearance中非None的特征
# 如果shape_or_objectness为None,则选择appearance,反之亦然
all_prototypes = (
shape_or_objectness if shape_or_objectness is not None else appearance
).unsqueeze(0)
# 返回最终形成的对象原型张量
return all_prototypes