目录
一、前言
过年节假日期间,将Google最新出来的BoTNet用于人脸识别,使用pytorch版本的BoTNet+ArcFace对人脸数据进行了训练。
二、训练准备
1、BoTNet的Pytorch版本
2、人脸识别数据和代码架构
(1)人脸识别数据和代码架构用的是https://github.com/TreB1eN/InsightFace_Pytorch。下载该工程,解压,在当前目录中建立文件夹:BoTNet。将BoTNet的model.py放入BoTNet文件夹中。
(2)修改config.py文件
在conf.use_mobilfacenet下面加入:
conf.use_mobilfacenet = False
conf.use_transfomer = False
(3)修改Learner.py文件
添加:
from BoTNet.model import ResNet,Bottleneck
在类face_learner中的__init__中引入conf.use_transfomer:
if conf.use_mobilfacenet:
self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
print('MobileFaceNet model generated')
elif conf.use_transfomer:
print("ues_transfomer")
self.model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=self.class_num).to(conf.device)
(4)修改train.py文件,加入:
conf.use_mobilfacenet = False
conf.use_transfomer = True
conf.lr = args.lr
conf.batch_size = 32#args.batch_size
conf.num_workers = args.num_workers
conf.data_mode = args.data_mode
learner = face_learner(conf)
注:这里因电脑本身CUDA内存原因,只能使用batch_size = 32。
3、完整训练代码
conf.use_mobilfacenet = False
conf.use_botnet = True
conf.use_t2tvit = False
说明:上述三个参数均为False时,使用的是ResNet(默认是ir_se,depth=50)进行人脸识别训练;剩下的哪个为True,使用对应的模型进行训练。
人脸识别数据参考源代码(https://github.com/TreB1eN/InsightFace_Pytorch)中的获取方式:
三、训练和结果
1、训练
保持原github上其余参数的设置,运行:
python3 train.py
2、结果
经过20个epochs的训练后,如下图所示,在LFW、CFP-FP、AgeDB-30中的识别率分别为:98.98%、92.18%、90.29%,和原github上提供的效果(LFW 0.9952、CFP-FP 0.9504、AgeDB-30 0.9622)还是有几个点的差距。后面通过调优看看能否提高效果。
LFW数据集在训练过程中的测试结果:
CFP-FP数据集在训练过程中的测试结果:
AgeDB-30数据集在训练过程中的测试结果