人脸素描属性识别:深度学习模型探索与性能评估
github代码:https://github.com/linkcao/FS2K_extract
FS2K数据集:https://github.com/DengPingFan/FS2K
项目概述
本项目需要根据FS2K数据集
进行训练和测试,实现输入一张人脸图片(真实图片或者素描图片),输入该图片的属性特征信息,提取属性特征包括hair
(有无头发)、hair_color
(头发颜色)、gender
(图像人物性别)、earring
(是否有耳环)、smile
(是否微笑)、frontal_face
(是否歪脖)、style
(图片风格),详细信息均可通过FS2K的anno_train.json
和anno_test.json
获取,本质是一个多标签分类问题。
- 本文探索了三种深度学习模型:VGG16、ResNet18和DenseNet121在该任务下的性能表现。
- 实验数据集
FS2K数据集
与实验结果均可在所给github仓库中获取,其中photo
代表原图,Sketch
代表素描图,数据集如下图所示:
处理方案
首先对于FS2K数据集用官方的数据划分程序进行划分,之后对划分后的数据进行预处理,统一图片后缀为jpg,之后自定义数据加载类,在数据加载过程中进行标签编码,对图片大小进行统一,并转成tensor,在处理过程中发现存在4个通道的图片,本文采取取前3个通道的方案,之后再对图像进行标准化,可以加快模型的收敛,处理完成的数据作为模型的输入,在深度学习模型方面,首先需要进行模型选择,本文使用了三个模型,分别为VGG16,ResNet121以及DenseNet121,在通过pytorch预训练模型进行加载,并修改模型输出层,输出数量为图片属性特征数,之后在设定模型训练的参数,包括Batch,学习率,epoch等,在每一轮训练完成后,都需要对预测出的特征进行处理,在二分类标签设定概率阈值,多分类标签特征列则进行最大概率类别组合,取预测概率最大的类别作为当前属性的预测结果,每一轮训练都在测试集上进行性能评估,并根据F1指标择优保存模型。训练完成后,在测试集上预测属性提取结果,对每一个属性进行性能评估,最后取平均,得到平均的性能指标。
整体的处理流程如下图所示:
数据预处理
- 数据划分: 根据FS2K官方给出的数据划分得到训练集和测试集
- 图片统一后缀和通道数: 统一图片后缀为jpg,通道数为3
- 统一图片大小:所给数据集分为三个文件夹,每个文件夹图片的像素各不相同,分别为250*250、475 *340、223 *318,这里统一变换成256 * 256,便于后序处理
- 转换为Tensor: 我们将图片数据转换为PyTorch的Tensor格式
- 图像标准化:逐channel的对图像进行标准化,可以加快模型的收敛
标签编码
- 二分类标签编码: 对于二分类属性(如
hair
、earring
等),我们将标签编码为0和1 - 多分类标签编码: 对于多分类属性(如
hair_color
、style
等),我们采用One-Hot编码进行处理,以便模型能够正确识别多个类别,例如:hair_color
中0 对应 [1,0,0,0,0], 1对应[0,1,0,0,0], 2对应[0,0,1,0,0],以此类推,共5类style
中 0 对应 [1,0,0],1对应[0,1,0], 2对应[0,0,1],以此类推,共3类
- 标签向量拼接: 将所有属性标签拼接成一个长度为13的标签向量,方便模型训练和预测。
实验模型
- 模型选择: 我们选择了VGG16、ResNet121和DenseNet121三种预训练模型作为候选模型
- 模型加载和修改: 使用PyTorch加载预训练模型,并根据任务需求修改输出层,使其输出属性特征数
- 训练参数设置: 设定训练参数,包括Batch大小、学习率、epoch等
- 性能评估: 每轮训练后,在测试集上进行性能评估,并根据F1指标选择最佳模型进行保存
VGG16
模型结构参数
由于VGG16最后一层全连接输出1000维特征,因此在本题中需要在加一层全连接输入1000维特征,输出13维特征,最后再加上一层sigmoid
激活函数,在得到每一类预测的概率后,针对编码过的hair_color、style的8列,对各自的编码后的对应列计算概率最大的列下标,作为该属性的预测值。
训练参数
batch | 64 |
---|---|
epoch | 20 |
optimizer (优化器) | SGD(随机梯度下降) |
criterion (损失函数) | BCELoss(二分类交叉熵损失) |
学习率 | 0.01 |
photo数据集上模型训练Loss
VGG16 实验结果 「方法一」
f1 | precision | recall | accuracy | |
---|---|---|---|---|
hair | 0.926064 | 0.903045 | 0.950287 | 0.950287 |
gender | 0.598046 | 0.611282 | 0.59369 | 0.59369 |
earring | 0.74061 | 0.674408 | 0.821224 | 0.821224 |
smile | 0.513038 | 0.580621 | 0.639579 | 0.639579 |
frontal_face | 0.758024 | 0.694976 | 0.833652 | 0.833652 |
hair_color | 0.351596 | 0.387132 | 0.389101 | 0.389101 |
style | 0.460469 | 0.526145 | 0.443595 | 0.443595 |
average | 0.668481 | 0.672201 | 0.708891 | 0.708891 |
结果分析
- 在hair、frontal_face和earring属性上取得了较高的性能,分别达到了95.03%、83.37%和82.12%的准确率。
- 对于gender和smile属性,性能较差,分别只达到了59.37%和63.96%的准确率。
- 在hair_color和style属性上的性能也较为一般,分别只达到了38.91%和44.36%的准确率。
- 平均性能指标为约66.85%,整体表现中等。
ResNet18
模型结构参数
模型修改 ,模型最后加一层全连接输入1000维特征,输出13维特征,最后再加上一层sigmoid
激活函数
训练参数
batch | 64 |
---|---|
epoch | 20 |
optimizer (优化器) | SGD(随机梯度下降) |
criterion (损失函数) | BCELoss(二分类交叉熵损失) |
学习率 | 0.01 |
photo数据集上模型训练Loss
ResNet18 photo数据集结果 「方法二」
f1 | precision | recall | accuracy | |
---|---|---|---|---|
hair | 0.926064 | 0.903045 | 0.950287 | 0.950287 |
gender | 0.657874 | 0.657195 | 0.6587 | 0.6587 |
earring | 0.744185 | 0.764809 | 0.821224 | 0.821224 |
smile | 0.634135 | 0.63298 | 0.652008 | 0.652008 |
frontal_face | 0.758024 | 0.694976 | 0.833652 | 0.833652 |
hair_color | 0.498804 | 0.515916 | 0.546845 | 0.546845 |
style | 0.508202 | 0.57917 | 0.482792 | 0.482792 |
average | 0.715911 | 0.718511 | 0.743188 | 0.743188 |
结果分析:
- 相较于VGG16,在gender和smile属性上取得了更好的性能,分别达到了65.87%和65.20%的准确率。
- 在hair、earring和frontal_face属性上性能相近,分别达到了95.03%、82.12%和83.37%的准确率。
- 对于hair_color和style属性,性能仍然较差,分别只达到了54.68%和48.28%的准确率。
- 平均性能指标为约71.59%,略优于VGG16模型。
Sketch数据集上模型训练Loss
sketch数据集结果 「方法三」
f1 | precision | recall | accuracy | |
---|---|---|---|---|
hair | 0.926064 | 0.903045 | 0.950287 | 0.950287 |
gender | 0.811982 | 0.813721 | 0.814532 | 0.814532 |
earring | 0.743495 | 0.720011 | 0.813576 | 0.813576 |
smile | 0.573169 | 0.573085 | 0.614723 | 0.614723 |
frontal_face | 0.758024 | 0.694976 | 0.833652 | 0.833652 |
hair_color | 0.358576 | 0.339481 | 0.419694 | 0.419694 |
style | 0.842575 | 0.942995 | 0.803059 | 0.803059 |
average | 0.751736 | 0.748414 | 0.78119 | 0.78119 |
DenseNet121
模型结构参数
训练参数
batch | 64 |
---|---|
epoch | 20 |
optimizer (优化器) | SGD(随机梯度下降) |
criterion (损失函数) | BCELoss(二分类交叉熵损失) |
学习率 | 0.01 |
photo数据集上模型训练Loss
DenseNet photo数据集结果 「方法四」
f1 | precision | recall | accuracy | |
---|---|---|---|---|
hair | 0.926064 | 0.903045 | 0.950287 | 0.950287 |
gender | 0.935669 | 0.936043 | 0.935946 | 0.935946 |
earring | 0.837358 | 0.837194 | 0.853728 | 0.853728 |
smile | 0.784984 | 0.787445 | 0.790631 | 0.790631 |
frontal_face | 0.780436 | 0.832682 | 0.8413 | 0.8413 |
hair_color | 0.685242 | 0.665904 | 0.718929 | 0.718929 |
style | 0.515421 | 0.567896 | 0.497132 | 0.497132 |
avg | 0.808147 | 0.816276 | 0.823494 | 0.823494 |
结果分析:
- 在gender和smile属性上取得了最佳性能,分别达到了93.57%和78.50%的准确率。
- 对于其他属性的性能也较为优秀,在hair、earring和frontal_face属性上达到了92.60%、83.73%和78.04%的准确率。
- 在hair_color属性上的性能相对较差,但仍然达到了68.52%的准确率。
- 在style属性上取得了最低的性能,仅为51.54%的准确率。
- 平均性能指标为约80.81%,在三个模型中表现最佳。
Sketch数据集上模型训练Loss
DenseNet sketch数据集结果 「方法五」
f1 | precision | recall | accuracy | |
---|---|---|---|---|
hair | 0.926064 | 0.903045 | 0.950287 | 0.950287 |
gender | 0.883773 | 0.886639 | 0.885277 | 0.885277 |
earring | 0.743196 | 0.734733 | 0.819312 | 0.819312 |
smile | 0.610952 | 0.661847 | 0.671128 | 0.671128 |
frontal_face | 0.758024 | 0.694976 | 0.833652 | 0.833652 |
hair_color | 0.372596 | 0.360252 | 0.423518 | 0.423518 |
style | 0.944535 | 0.96071 | 0.938815 | 0.938815 |
avg | 0.779892 | 0.775275 | 0.815249 | 0.815249 |
整体结果对比
FS2K Photo数据集下各模型F1值:
FS2K Sketch数据集下各模型F1值:
- VGG16 vs. ResNet18 vs. DenseNet121:
- 在对比三种不同的预训练模型在相同数据集上的性能时,可以观察到 DenseNet121 在大多数属性上取得了最佳性能,其次是 ResNet18,最后是 VGG16。这表明了在相同的任务上,更深层次的模型结构往往能够更好地提取图像特征,从而提高模型的性能。
- 属性性能差异:
- 在所有模型中,对于一些属性如 hair、frontal_face 和 earring,模型性能普遍较好,这可能是因为这些属性在图像中具有明显的特征,容易被模型学习和识别。相比之下,像 gender 和 smile 这样的属性可能更为主观和抽象,因此模型的性能相对较低。
- 特别是在 hair_color 和 style 属性上,模型性能普遍较差,这可能是因为这些属性的识别更为复杂,受到光照、姿势、背景等因素的影响较大。
- 平均性能指标对比:
- 在三种模型中,DenseNet121 在平均性能指标上取得了最佳表现,这意味着它在多个属性的识别上具有更稳定和优秀的性能。ResNet18 次之,VGG16 表现最差。