一、写在开头
这次的需求是:在Faceboxes的基础上,参照Retinaface加入了关键点信息。同时学习下Retinaface的代码和pytorch框架。
Retinaface论文地址
Retinaface代码地址
Faceboxes代码地址
主要以记录Retinaface为主,Faceboxes中需要参照前面修改的,会指出来。
二、主要内容
主要围绕其train.py脚本展开,从数据处理,默认框生成,网络结构,损失函数,几个方面记录。
数据处理
:
1.在train.py中相关联的主要是下面几行:
dataset = WiderFaceDetection(training_dataset,preproc(img_dim, rgb_mean))
batch_iterator = iter(data.DataLoader(dataset, batch_size, shuffle=True, num_workers=num_workers, collate_fn=detection_collate))
images, targets = next(batch_iterator)
2.所涉及的脚本:
wider_face.py
data_augment.py
默认框生成
:
1.在train.py中相关联的主要是下面几行:
priorbox = PriorBox(cfg, image_size=(img_dim, img_dim))
with torch.no_grad():
priors = priorbox.forward()
priors = priors.cuda()
2.涉及的脚本
prior_box.py
网络结构
:
1.在train.py中相关联的主要是下面几行:
net = RetinaFace(cfg=cfg)
out = net(images)
2.涉及的脚本
net.py
retinaface.py
损失函数
:
1.在train.py中相关联的主要是下面几行:
criterion = MultiBoxLoss(num_classes, 0.35, True, 0, True, 7, 0.35, False)
loss_l, loss_c, loss_landm = criterion(out, priors, targets)
2.涉及的脚本
multibox_loss.py
box_utils.py