GeneralizedRCNN 继承基类 nn.Module 。
class GeneralizedRCNN(nn.Module):
def __init__(self, backbone, rpn, roi_heads, transform):
"""
初始化函数,定义了 GeneralizedRCNN 类的属性
参数:
backbone: 主干网络,用于从输入图像中提取特征
rpn: 区域建议网络,用于生成图像中的候选区域
roi_heads: 感兴趣区域头部,用于对候选区域进行分类和定位
transform: 对输入图像进行转换的对象,用于处理输入图像使其适合于模型的输入要求
"""
super(GeneralizedRCNN, self).__init__()
self.transform = transform
self.backbone = backbone
self.rpn = rpn
self.roi_heads = roi_heads
def forward(self, images, targets=None):
# 创建一个空的列表,用于保存每个图像的高度和宽度信息
original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
# 遍历输入的图像列表
for img in images:
# 获取图像的高度和宽度
val = img.shape[-2:]
# 断言确保图像的形状是二维的
assert len(val) == 2
# 将图像的高度和宽度信息添加到列表中
original_image_sizes.append((val[0], val[1]))
images, targets = self.transform(images, targets)
# 使用主干网络从图像中提取特征
features = self.backbone(images.tensors)
# 如果特征是一个张量,则将其转换成 OrderedDict,键为'0',以便与后续的代码兼容
if isinstance(features, torch.Tensor):
features = OrderedDict([('0', features)])
# 使用区域建议网络获取候选区域及其损失
proposals, proposal_losses = self.rpn(images, features, targets)
# 使用感兴趣区域头部进行目标检测,并计算相应的损失
detections, detector_losses = self.roi_heads(features, proposals, images.image_sizes, targets)
# 使用图像处理后处理检测结果,确保它们符合原始图像大小
detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)
# 整合并返回检测和建议的损失
losses = {}
losses.update(detector_losses)
losses.update(proposal_losses)
# 返回损失和检测结果
return (losses, detections)