CenterNet(Objects as points)开源代码:
https://github.com/xingyizhou/CenterNethttps://github.com/xingyizhou/CenterNet
1、环境安装
在GitHub下载centernet代码,并按照CenterNet中的reademe-Installation部分搭建centernet运行环境
需要注意的是:在原版Installation部分需要安装pytorch0.4(因为用到的DCNV2对应了这个pytorch),根据bilibili up主-Guuuuuu老师儿 讲解的使用Pytorch搭建centernet视频,得知可使用1.4版本的pytorch以及下载github上最新的dcnv2,搭配使用可run centernet。但是我的显卡支持1.10版本以上的,所以1.4的pytorch也不能满足我的需求,为了快速复现centernet,没有在搭配环境上下太多功夫,按照bilibili视频安装了pytorch1.4及替换了DCNV2文件夹(替换以后还需要build等操作具体见视频)。在centernet代码中使用cuda加速部分注释掉使用cpu进行训练,这样和pytorch版本也就没什么关系了(这么说按照原版的安装0.4其实也可以。。。)
ps,视频中讲解运行demo.py 使用命令行运行,为了使用vscode进行调试,其他的配置参数可以加在.lanuch.json文件中
"args": [
"ctdet",
"--demo","../images/19064748793_bb942deea1_k.jpg",
"--load_model","../models/ctdet_coco_dla_2x.pth"
]
2、centernet代码说明
(1)配置文件opts.py
配置使用centernet进行哪种任务,使用哪种数据集,数据集格式使用coco | kitti | coco_hp | pascal。配置主干网络模型使用'res_18 | res_101 | resdcn_18 | resdcn_101 |'dlav0_34 | dla_34 | hourglass等配置。
class opts(object):
def __init__(self):
self.parser = argparse.ArgumentParser()
# basic experiment setting
self.parser.add_argument('task', default='ctdet',
help='ctdet | ddd | multi_pose | exdet')
self.parser.add_argument('--dataset', default='coco',
help='coco | kitti | coco_hp | pascal')
self.parser.add_argument('--exp_id', default='default')
self.parser.add_argument('--test', action='store_true')
(2)制作自己数据集的真值
参考从代码角度分析高效优雅检测模型CenterNet - 作业部落 Cmd Markdown 编辑阅读器
整个真值生成过程代码在src/lib/datasets/sample/ctdet.py,其外部采用的是多继承的方式实现dataset,
def get_dataset(dataset, task):
class Dataset(dataset_factory[dataset], _sample_factory[task]):
pass
return Dataset
a、获取数据集及标签路径等基础信息
其中dataset_factory[dataset]是COCO数据解析类,_sample_factory[task]是目标检测的CTDetDataset,这两个类都是继承至pytorch的Dataset类。
对于python多继承而言,是按照先后顺序继承的,COCO类实现获得数据集路径以及标签json文件的路径及名称,对象类别数量,images的id等。而CTDetDataset才是真正实现了getitem方法,两个类的全部方法和属性合并才得到最终的datalayer层。这样写的好处很明显就是解耦,如果数据格式变了或是说是coco格式,但是内部变量数据值变了,此时就可以仅仅额外提供一个和COCO类一样的py文件即可,而不需要重写CTDetDataset类。但是这样写的缺点也很明显:代码可阅读性降低了很多,而且在CTDetDataset里面强制读取COCO类的属性,是没有代码提示的,因为如果不是多继承的写法,实际运行时候肯定是会报错了,加入了多继承后,子类就可以读取到父类里面的任何一个方法和属性。虽然看起来很优雅,但是这种实现方式不推荐,严重违背迪米特法则。
COCO这个类仅仅是为后面的CTDetDataset提供一些数据和属性。
b、获取标签的具体信息如annotations里的信息
Dataloader如何加载json中的annotations
参照【PyTorch深度学习实践】学习笔记 数据集的加载Dataset和DataLoader原理_咯吱咯吱咕嘟咕嘟的博客-CSDN博客
上面的博客写的很详细了,我这里简单写一下流程
b1、搭建dataloader
DataLoader( dataset = dataset , #dataset 是继承了dataset类之后加载数据集提供路径
batch_size = 32, #选择batch_size的大小
shuffle = true, #增强数据集随机性
num_workers = 2 ) #多进程读数据
在enumerate(Dataloader)中具体读到json文件中的annotations
for i, data in enumerate(train_loader):
若是没有跳转到dataloader.py文件中需要在lanuch.json文件中添加如下配置信息
"justMyCode": false,
b2、跳转到
def __iter__(self):
if self.num_workers == 0:
return _SingleProcessDataLoaderIter(self)
else:
return _MultiProcessingDataLoaderIter(self) #进程问题
b3、跳转到
class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
def __init__(self,loader):
super(_SingleProcessDataLoaderIter,self).__init__(loader)
assert self.timeout == 0
assert self.num_workers == 0
self.dataset_fetcher = _DatasetKind.create_fetcher(self.dataset_kind, self.dataset,self.auto_collation, self.collate_fn, self.drop_last)
def __next__(self):
index = self._next_index() # may raise StopIteration
data = self.dataset_fetcher.fetch(index) # may raise StopIteration
if self.pin_memory:
data = _utils.pin_memory.pin_memory(data)
return data
next = __next__ # Python 2 compatibility
b4、跳转到
def _next_index(self):
return next(self.sampler_iter) # may raise StopIteration
b5、跳转到
class _MapDatasetFetcher(_BaseDatasetFetcher):
def __init__(self, dataset, auto_collation, collate_fn, drop_last):
super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
def fetch(self, possibly_batched_index):
if self.auto_collation:
data = [self.dataset[idx] for idx in possibly_batched_index] #调用了dataset,通过一系列的data拼接成一个list;
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
b6、跳转到CTDetDataset类的getitem函数
class CTDetDataset(data.Dataset):
def __getitem__(self, index):
在这个函数中我们对读取到的json进行制作,制作我们想要的真值。
在我的数据集中,原图为1280,标签也为1280。将每一个json合并到数据集的大json文件时,将图片及标签都转为了512大小(如何转?直接将json文件中的坐标*512/1280)
b6-1、取出anns中的前角点,后角点,chock点,edge点以及是否被占等信息存入joint中
b6-2、将这些点即joint再缩小为128,对这些点画gussian。
b6-3、将得到的gussian,写入图片