CenterNet:Objects as Points代码解析(四) :CenterNet/src/lib/datasets/sample/cedet

在整个代码中cedet.py有两个,本文中的在lib/datasets下,作用是对用于任务cedt的数据进行进行处理,得到训练时用于前向传播的inp(input输入图像), 和训练时用于和输出预测值进行比较得到各项损失的hm, reg_mask, ind, wh地面真值。

# – coding:utf-8 –
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch.utils.data as data
import numpy as np
import torch
import json
import cv2
import os
from utils.image import flip, color_aug
from utils.image import get_affine_transform, affine_transform
from utils.image import gaussian_radius, draw_umich_gaussian, draw_msra_gaussian
from utils.image import draw_dense_reg
import math

class CTDetDataset(data.Dataset):
  def _coco_box_to_bbox(self, box):
  # box[0], box[1], box[2], box[3]分别代表xmin,ymin,w,h=====>>执行下面一句代码后,分别代表xmin,ymin,xmax,ymax
    bbox = np.array([box[0], box[1], box[0] + box[2], box[1] + box[3]],
                    dtype=np.float32)
    return bbox

  def _get_border(self, border, size):
    i = 1
    #如果图像宽(或高)小于等于2×boder,则i增大为2,返回border // i,
    #否则,如果图像宽(或高)大于 2×boder,则i不变,返回border
    while size - border // i <= border // i: #化简后为size <= 2×border
        i *= 2
    return border // i
  
  #这里我们可以得到我们输出参数(字典形式),分别是inp(input输入图像), hm, reg_mask, ind, wh。
  
  #??????第一个index为什么是12111,通过跳进来的语句enumerate(data_loader),第一个迭代应该是data_loader[0],所以index不应该是0吗???????????
  #因为dataloader中存储的训练数据是一个批量一个批量的(由main.py中的train_loader =torch.utils.data.DataLoader()的参数中有batch_size看出),
  所以枚举时,一次取出的是一个批量(即,data_loader[0]代表一个批量), __getitem__魔法方法一次执行一个批量里的一张图片,且批量是打乱数据后取得,
  所以第一次执行__getitem__处理的第一个批量的第一个图片索引不一定是0。由于使用了随机种子,所以这里的第一个索引一直是一样的。
  
  #该函数一次处理一张图片,对于enumerate()的一次迭代,一共循环执行batch_size次
  def __getitem__(self, index):
    img_id = self.images[index]
    file_name = self.coco.loadImgs(ids=[img_id])[0]['file_name']
    img_path = os.path.join(self.img_dir, file_name)
    ann_ids = self.coco.getAnnIds(imgIds=[img_id])#注释索引,是一个列表,其中元素个数是该图片中包含的标注对象的个数。
    #anns是一个列表,其元素是字典,每一个字典描述一个对象.
    anns = self.coco.loadAnns(ids=ann_ids)
    #num_objs为我们一张图片选取top中心点的数量(即类似起到NMS作用)。若图中标注的对象数量K小于self.max_objs,则该图片图片选取K个中心点的数量,否则该图片图片选取self.max_objs个中心点的数量。
    #len(anns)表示一张图片中标注对象的数量。anns是一个列表,其元素是字典,每一个字典描述一个对象,所以anns的长度len(anns)表示一张图片中标注对象的数量。
    num_objs = min(len(anns), self.max_objs)

    img = cv2.imread(img_path)
    #提取图片的原始尺寸,是为了下面的数据增强----裁剪
    height, width = img.shape[0], img.shape[1]
    #图片的中心点,为下面裁剪(crop)做准备,图片中心点用图片宽、高除以2算出,而不是用坐标值算出,是因为原点在图片的左上角,所以图片宽、高除2的结果就是中点的坐标。
    c = np.array([img.shape[1] / 2., img.shape[0] / 2.], dtype=np.float32)
    if self.opt.keep_res:
      input_h = (height | self.opt.pad) + 1
      input_w = (width | self.opt.pad) + 1
      s = np.array([input_w, input_h], dtype=np.float32)
    else:
      s = max(img.shape[0], img.shape[1]) * 1.0  #s:最长的边长,为下面的随机缩放做准备
      input_h, input_w = self.opt.input_h, self.opt.input_w
    #================接下来为了保持数据的泛化性,对数据进行一系列处理,得到我们第一个所需要的输入图像inp(input)=========
    flipped = False
    if self.split == 'train':
      #裁剪,当不随机裁剪时,进行如下数据增强
      if not self.opt.not_rand_crop:
        #随机缩放(介于0.61.3之间)
        s = s * np.random.choice(np.arange(0.6, 1.4, 0.1))
        #self._get_border函数是为了保证下面的np.random.randint(low=? , high=?)中的low值小于high值。
        w_border = self._get_border(128, img.shape[1])
        h_border = self._get_border(128, img.shape[0])
        c[0] = np.random.randint(low=w_border, high=img.shape[1] - w_border)
        c[1] = np.random.randint(low=h_border, high=img.shape[0] - h_border)
      else:
        sf = self.opt.scale
        cf = self.opt.shift
        c[0] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
        c[1] += s * np.clip(np.random.randn()*cf, -2*cf, 2*cf)
        s = s * np.clip(np.random.randn()*sf + 1, 1 - sf, 1 + sf)
      
      if np.random.random() < self.opt.flip:
        flipped = True
        img = img[:, ::-1, :]
        c[0] =  width - c[0] - 1
        
    #get_affine_transform:得到仿射变换矩阵 ;trans_input:用于输入缩放的仿射变换矩阵
    trans_input = get_affine_transform(
      c, s, 0, [input_w, input_h])
    #输入图像经过仿射变换的缩放,缩放成网络的固定输入(input_w, input_h)384×382512×512
    inp = cv2.warpAffine(img, trans_input, 
                         (input_w, input_h),
                         flags=cv2.INTER_LINEAR)
    inp = (inp.astype(np.float32) / 255.)
    #no_color_aug :没有颜色增强
    if self.split == 'train' and not self.opt.no_color_aug:
      color_aug(self._data_rng, inp, self._eig_val, self._eig_vec)
    # 归一化,利于梯度下降,有助于收敛,一般都有归一化这步
    inp = (inp - self.mean) / self.std
    inp = inp.transpose(2, 0, 1)
    #==============接着我们需要完成我们的heatmap的生成。===============
    output_h = input_h // self.opt.down_ratio
    output_w = input_w // self.opt.down_ratio
    num_classes = self.num_classes
    #这里为什么要进行get_affine_transform????????
    (有待商定)为下面的bbox变换做准备,trans_output:输出的仿射变换矩阵
    trans_output = get_affine_transform(c, s, 0, [output_w, output_h])

    hm = np.zeros((num_classes, output_h, output_w), dtype=np.float32)
    # 中心点宽高(self.max_objs*2)
    wh = np.zeros((self.max_objs, 2), dtype=np.float32)
    dense_wh = np.zeros((2, output_h, output_w), dtype=np.float32)
    # reg是偏置回归数组,记录下采样带来的误差(误差产生原因 :因为将float转成int类型而产生的误差),返回self.max_objs*2的小数
    reg = np.zeros((self.max_objs, 2), dtype=np.float32)
    #目标中心点在特征图中的索引,返回self.max_objs个ind(索引)
    ind = np.zeros((self.max_objs), dtype=np.int64)
    #reg_mask回归的是有无目标,以掩码mask是否等1表示,返回self.max_objs个回归mask
    #这里相当于记载一张图片存在哪些目标,有的话对应索引设置为1,其余设置为0。
    reg_mask = np.zeros((self.max_objs), dtype=np.uint8)
    cat_spec_wh = np.zeros((self.max_objs, num_classes * 2), dtype=np.float32)
    cat_spec_mask = np.zeros((self.max_objs, num_classes * 2), dtype=np.uint8)

    #---------------用高斯函数画出heatmap----------------
    draw_gaussian = draw_msra_gaussian if self.opt.mse_loss else \
                    draw_umich_gaussian

    gt_det = []
    for k in range(num_objs):
      ann = anns[k]#anns是一个列表,其元素是字典,每一个字典描述一个对象,即,K代表图片中的某一个标注对象.
      bbox = self._coco_box_to_bbox(ann['bbox'])
      #cls_id :类别cls对应的索引
      cls_id = int(self.cat_ids[ann['category_id']])
      if flipped:
        bbox[[0, 2]] = width - bbox[[2, 0]] - 1
      #这为什么要对bbox进行进行仿射变换???????
      #(有待商讨)将bbox根据模型输出尺寸进行缩放----是因为前面的输入图片已经进行了仿射变换,得到了新的输入,原图像上的bbox自然也改变了,所以也要将原图像上的bbox进行仿射变换,以和变换后的输入对应。
      bbox[:2] = affine_transform(bbox[:2], trans_output)
      bbox[2:] = affine_transform(bbox[2:], trans_output)
      #做clip(裁剪)的目的是,保证仿射变换后的左上角、右下角的坐标还在[0,output_w(output_h) - 1]内。
      bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0, output_w - 1)
      bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0, output_h - 1)
      h, w = bbox[3] - bbox[1], bbox[2] - bbox[0]
      if h > 0 and w > 0:
        #ceil() 函数返回数字的上入整数。j即,向上取整
        radius = gaussian_radius((math.ceil(h), math.ceil(w)))
        radius = max(0, int(radius))
        radius = self.opt.hm_gauss if self.opt.mse_loss else radius
        ct = np.array(
          [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], dtype=np.float32)
        ct_int = ct.astype(np.int32)
        #hm[cls_id] :hm的维度(num_classes, output_h, output_w),是一个3维度矩阵,则hm[cls_id]表示一个类(cls_id)对应的一整个通道.
        draw_gaussian(hm[cls_id], ct_int, radius)
        wh[k] = 1. * w, 1. * h # 目标矩形框的宽高(目标尺寸)的损失
        #目标中心点在特征图中的索引 : 目标中心点的横坐标ct_int[0]代表该中心点处在第几列(即,某一行的第几个),而纵坐标ct_int[1]代表该中心点处在第几行;
        #又因为output_w 是特征图的宽(即,每一行有多少个特征点),所以ct_int[1] * output_w + ct_int[0]表示中心点处在特征图上的第几个特征点,从特征图左上角一行一行的开始计数。
        ind[k] = ct_int[1] * output_w + ct_int[0]
        #reg是偏置(offset)回归数组,存放每个中心点的偏置值,k是当前图中第k个目标
        reg[k] = ct - ct_int
        reg_mask[k] = 1   #有目标的位置的mask
        #一个目标回归w、h两个元素,对应下面两行代码中2的含义
        cat_spec_wh[k, cls_id * 2: cls_id * 2 + 2] = wh[k]
        cat_spec_mask[k, cls_id * 2: cls_id * 2 + 2] = 1
        if self.opt.dense_wh:
          draw_dense_reg(dense_wh, hm.max(axis=0), ct_int, wh[k], radius)
        gt_det.append([ct[0] - w / 2, ct[1] - h / 2, 
                       ct[0] + w / 2, ct[1] + h / 2, 1, cls_id])
    
    ret = {'input': inp, 'hm': hm, 'reg_mask': reg_mask, 'ind': ind, 'wh': wh}
    if self.opt.dense_wh:
      hm_a = hm.max(axis=0, keepdims=True)
      dense_wh_mask = np.concatenate([hm_a, hm_a], axis=0)
      ret.update({'dense_wh': dense_wh, 'dense_wh_mask': dense_wh_mask})
      del ret['wh']
    elif self.opt.cat_spec_wh:
      ret.update({'cat_spec_wh': cat_spec_wh, 'cat_spec_mask': cat_spec_mask})
      del ret['wh']
    if self.opt.reg_offset:
      #ret更新, 即把reg这个item添加到ret字典中去。
      # 原文中提到加不加偏置(offset)回归影响不大,所以这里可选择加不加reg
      ret.update({'reg': reg})
    if self.opt.debug > 0 or not self.split == 'train':
      gt_det = np.array(gt_det, dtype=np.float32) if len(gt_det) > 0 else \
               np.zeros((1, 6), dtype=np.float32)
      meta = {'c': c, 's': s, 'gt_det': gt_det, 'img_id': img_id}
      ret['meta'] = meta
    return ret
  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值