【深度学习】CenterNet-better-plus代码解读

本文详细解读了CenterNet-better-plus的代码实现,重点探讨了detectron2架构及其ground truth(gt)生成过程。通过源代码分析,揭示了模型处理inputs的细节,包括gt_bbox和gt_classes等关键信息。
摘要由CSDN通过智能技术生成

CenterNet-better-plus代码解读

源代码 https://github.com/lbin/CenterNet-better-plus

detectron2架构

gt生成部分

centernet/centernet_gt.py

首先看一下经过处理的inputs
在这里插入图片描述
instance类里面包含了图片的gt_bbox、gt_classes等信息

import numpy as np
import torch


class CenterNetGT(object):
    @staticmethod
    def generate(config, batched_input):
        
        box_scale = 1 / config.MODEL.CENTERNET.DOWN_SCALE
        num_classes = config.MODEL.CENTERNET.NUM_CLASSES
        output_size = config.MODEL.CENTERNET.OUTPUT_SIZE #[128,128]
        min_overlap = config.MODEL.CENTERNET.MIN_OVERLAP # 0.7
        tensor_dim = config.MODEL.CENTERNET.TENSOR_DIM # 128

        scoremap_list, wh_list, reg_list, reg_mask_list, index_list = [[] for i in range(5)] #初始化gt列表
        
        for data in batched_input:
            # img_size = (data['height'], data['width'])

            bbox_dict = data["instances"].get_fields() # 字典形式保存了gt_bbox和gt_classes

            # init gt tensors
            gt_scoremap = torch.zeros(num_classes, *output_size) # torch.Size([80, 128, 128]) cpu
            gt_wh = torch.zeros(tensor_dim, 2) # torch.Size([128, 2])
            gt_reg = torch.zeros_like(gt_wh) # torch.Size([128, 2]) 中心点的偏移量
            reg_mask = torch.zeros(tensor_dim) # torch.Size([128])
            gt_index = torch.zeros(tensor_dim) # torch.Size([128])
            # pass

            boxes, classes = bbox_dict["gt_boxes"], bbox_dict["gt_classes"]
            num_boxes = boxes.tensor.shape[0]
            boxes.scale(box_scale, box_scale)

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值