import torch
from torch import nn
class KeypointPostProcessor(nn.Module):
def __init__(self, keypointer=None):
super(KeypointPostProcessor, self).__init__()
self.keypointer = keypointer # xy_preds(x,1,y), end_scores
def forward(self, x, boxes):
mask_prob = x
scores = None
if self.keypointer:
mask_prob, scores = self.keypointer(x, boxes) # mask_prob和 boxes →→ mask_prob, scores
assert len(boxes) == 1, "Only non-batched inference supported for now"
boxes_per_image = [box.bbox.size(0) for box in boxes]
mask_prob = mask_prob.split(boxes_per_image, dim=0)
scores = scores.split(boxes_per_image, dim=0)
results = []
for prob, box, score in zip(mask_prob, boxes, scores):
bbox = BoxList(box.bbox, box.size, mode="xyxy")
for field in box
maskrcnn_benchmark理解记录——modeling\roi_heads\keypoint_head\inference.py
最新推荐文章于 2024-11-20 00:13:18 发布
本文详细记录了对maskrcnn_benchmark中modeling.roi_heads.keypoint_head模块的inference.py文件的理解,主要涉及关键点检测的推理流程,包括特征提取、关键点预测和后处理步骤。

最低0.47元/天 解锁文章
1万+

被折叠的 条评论
为什么被折叠?



