之前在学习pytorch框架时候,有个老师说整个的使用过程,不外乎四个过程:1)数据准备(预处理、数据加载等);2)模型定义;3)损失函数与优化器定义;4)进行训练
下面我将通过这样四个步骤对代码进行一个阅读。(划重点: 自我理解!)
1. 数据准备
文章使用的是COCO的数据集,对数据的基本处理在$ROOT/core/dbs下面,文件组织关系BASE–>DETECTION–>COCO,base中定义一些基本属性,detection中对参数进行了初始化,在coco文件中进行了与训练任务相关数据集的信息,如均值方差、类别等。在这里也实现了对数据的加载,以及对标签的转换,转换为[x1,y1,x2,y2,class]的数组。
2. 模型构建
CornerNet-Lite中有三个网络,位于$ROOT/models下
2.1 CornerNet
class model继承了hg_net,但这里的hg_net并不是真真正的沙漏网络结构,真正的hg在$ROOT/core/models/py_utils/moudules.py中的class hg(nn.Module),叙述具体细节之前,先看一下CornerNet各模块的定义流程,
在CornerNet.py中的class model(hg_net):我们可以看到如下几个函数:
# 此模块用于在网络最后接预测(heatmap、tag、offset)
def _pred_mod(self, dim):
return nn.Sequential(
# 使用1x1kernel将通道数变为dim
convolution(3, 256, 256, with_bn=False),
nn.Conv2d(256, dim, (1, 1))
)
# 主要用来做归一化
def _merge_mod(self):
return nn.Sequential(
nn.Conv2d(256, 256, (1, 1), bias=False),
nn.BatchNorm2d(256)