年龄性别预测3:Android实现年龄性别预测和识别(含源码,可实时预测)

年龄性别预测3:Android实现年龄性别预测和识别(含源码,可实时预测)

目录

年龄性别预测3:Android实现年龄性别预测和识别(含源码,可实时预测)

1.年龄性别预测和识别方法

2.人脸检测方法

3.年龄性别预测和识别模型训练

(1)年龄性别预测和识别模型训练

(2) 将Pytorch模型转换ONNX模型

(3) 将ONNX模型转换为TNN模型

4.年龄性别预测和识别模型Android部署

 (1) Android开发版本

(2) Android端上部署模型

(3) Android测试效果 

(4) 运行APP闪退:dlopen failed: library "libomp.so" not found

5.项目源码下载


本项目将实现年龄性别预测和识别,整套方案采用二阶段方法实现,即首先使用通用的人脸检测算法(Face Detection)定位人脸区域,裁剪人脸,再构建多任务模型,分别进行年龄预测(Age)和性别识别(Gender)。项目分为数据集说明,Pytorch模型训练和C++/Android部署等多个章节,本篇是项目《年龄性别预测》系列文章之Android实现年龄性别预测和识别;本篇主要分享将Python训练后的年龄性别预测和识别模型移植到Android平台。我们将开发一个简易的、可实时运行的年龄性别识别的Android Demo。项目将手把手教你将训练好的模型部署到Android平台中,包括如何转为ONNX,TNN模型,并移植到Android上进行部署,实现一个年龄性别识别的Android Demo APP 。APP在普通Android手机上可以达到实时的检测识别效果,CPU(4线程)约30ms左右,GPU约25ms左右 ,基本满足业务的性能需求。

c4269d333aac4094a47166dc1c9e0b99.png

先展示一下Android版本年龄性别识别Demo效果:https://download.csdn.net/download/guyuealian/88743711

06806202ad494d489ad44d072b01a949.gifec2a11af44ba4efdb09d2f5200bdbb66.gif

项目基于深度学习Pytorch,构建了整套年龄性别预测和识别模型训练和测试框架;项目源码backbone模型支持的有resnet18,resnet50,以及轻量化模型mobilenet_v2等常见的深度学习模型,用户也可自定义模型进行训练;准确率还挺高的,采用轻量级mobilenet_v2模型的性别识别准确率0.9603左右,年龄预测MAE(平均绝对误差:)3.1935左右,CS5(预测年龄与真实年龄的绝对误差不过5年的准确率)0.8021左右,基本满足业务性能需求。

考虑到resnet18和resnet50模型计算量比较大,不合适部署到Android平台,本篇Android源码只部署mobilenet_v2模型实现年龄性别预测

 年龄性别预测和识别Android APP Demo体验:https://download.csdn.net/download/guyuealian/88743711

模型input size性别准确率年龄MAE年龄CS3年龄CS5

AE_mobilenet_v2

112×112

0.9603

3.1935

0.5969

0.8021

AE_resnet18

112×112

0.9606

3.1795

0.5956

0.8010

AE_resnet50

112×112

0.9609

3.2008

0.5900

0.8035

 【尊重原创,转载请注明出处】 https://blog.csdn.net/guyuealian/article/details/135556824 


 更多项目《年龄性别预测》和《面部表情识别》系列文章请参考:

  1. 面部表情识别1:表情识别数据集(含下载链接)
  2. 面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码)
  3. 面部表情识别3:Android实现表情识别(含源码,可实时检测)
  4. 面部表情识别4:C++实现表情识别(含源码,可实时检测)
  5. 年龄性别预测1:年龄性别数据集说明(含下载地址)
  6. 年龄性别预测2:Pytorch实现年龄性别预测和识别(含训练代码和数据)
  7. 年龄性别预测3:Android实现年龄性别预测和识别(含源码,可实时预测)
  8. 年龄性别预测4:C/C++实现年龄性别预测和识别(含源码,可实时预测)

139a07d0b14741b4a7b80dee106c2b1f.gif


1.年龄性别预测和识别方法

年龄性别预测和识别方法有多种实现方案,这里采用最常规的二阶段方法实现,即首先使用通用的人脸检测算法(Face Detection)定位人脸区域,裁剪人脸,再构建多任务模型,分别进行年龄预测(Age)和性别识别(Gender)。

  • 人脸检测:人脸检测算法已经有很多成熟开源项目了,本项目不作分析,可以参考使用MTCNN,DSFD,FaceBoxes,RFB等方法。
  • 性别识别:性别识别是一个简单二分类,训练使用交叉熵损失函数即可
  • 年龄预测:可以采用年龄分类方法,如SSR-Net模型,也可以结合回归的方法,如Label Distribution。就调研而言,基于Label Distribution的方法会比分类方法准确率会高一些。

下图本项目构建的年龄性别预测和识别模型,其中

  • Backbone:主杆网络,用于提取人脸图像特征,可以使用任意的骨干网络,如resnet18,resnet34,resnet50以及轻量化模型mobilenet_v2等
  • Gender-branch: 性别识别分支,用于对性别进行分类识别,损失函数使用交叉熵
  • Age-branch: 年龄预测分支,对年龄进行预测,损失函数可以使用交叉熵或者Label Distribution,项目设定最大周岁是70周岁,训练数据中年龄大于70的lalel,会重置为70。

698d2400a5674176bc14e5c47660a294.png


2.人脸检测方法

本项目人脸检测训练代码请参考:https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB 

这是一个基于SSD改进且轻量化后人脸检测模型,很轻巧,整个模型仅仅1.7M左右,在普通Android手机都可以实时检测。人脸检测方法在网上有一大堆现成的方法可以使用,如MTCNN,DSFD,FaceBoxes,完全可以不局限我这个方法。

da566a77c27c6d21c8d7ca477782e4d8.jpeg​​

关于人脸检测的方法,可以参考我的另一篇博客:

行人检测和人脸检测和人脸关键点检测(C++/Android源码)


3.年龄性别预测和识别模型训练

(1)年龄性别预测和识别模型训练

本篇博文不含Python版本的模型以及相关训练代码,关于年龄性别预测和识别模型的训练方法,请参考本人另一篇博文《Pytorch实现年龄性别预测和识别(含训练代码和数据)》;

Python项目源码backbone模型支持的有resnet18,resnet50,以及轻量化模型mobilenet_v2等常见的深度学习模型,用户也可自定义模型进行训练;准确率还挺高的,采用轻量级mobilenet_v2模型的性别识别准确率0.9603左右,年龄预测MAE(平均绝对误差:)3.1935左右,CS5(预测年龄与真实年龄的绝对误差不过5年的准确率)0.8021左右,基本满足业务性能需求。

模型input size性别准确率年龄MAE年龄CS3年龄CS5

AE_mobilenet_v2

112×112

0.9603

3.1935

0.5969

0.8021

AE_resnet18

112×112

0.9606

3.1795

0.5956

0.8010

AE_resnet50

112×112

0.9609

3.2008

0.5900

0.8035

关于模型训练以及测试的建议:

  1. ​ 关于性别识别的问题:目前性别识别的准确率约96%,识别错误的主要有两种情况,(1)儿童性别容易误识别,特别是1~3岁左右的儿童,性别识别比较困难 (2) 长头发的男生或者短头发的女生的,也容易误识别;其他情况,正常穿着打扮的男士女生识别准确率可以达到99%左右。
  2. 关于年龄预测的问题:现有数据年龄部分不均匀,大部分人脸数据年龄分布在20-40岁之间的年轻人,而儿童和老年人的数据比较少;导致儿童和老年人年龄预测精准度比较差;另外,也是强烈建议的:采集同一个人不同年龄阶段的人脸数据加入模型训练,可以有效提升年龄预测的精准度。损失函数使用Label Distribution方法进行训练,也会比直接使用交叉熵损失函数效果要好。
  3. 当人脸存在遮挡时,如戴眼镜,戴口罩,头发遮挡,年龄预测的误差较大,建议实际使用过程中,尽量采集正脸,无遮挡的人脸图片进行测试
  4. 清洗数据集(最重要):尽管鄙人已经清洗一部分了,但还是建议你,训练前,再次清洗数据集,不然会影响模型的识别的准确率。
  5. 增加训练的样本数据: 建议根据自己的业务场景,采集相关数据,提高模型泛化能力
  6. 使用参数量更大的模型: 本教程使用的是mobilenet_v2模型,属于比较轻量级的分类模型,采用更大的模型(如resnet50),理论上其精度更高,但推理速度也较慢。
  7. 尝试不同数据增强的组合进行训练
  8. 增加数据增强: 已经支持: 随机裁剪,随机翻转,随机旋转,颜色变换等数据增强方式,可以尝试诸如mixup,CutMix等更复杂的数据增强方式
  9. 样本均衡: 原始数据年龄类别数据并不均衡,类别20-40岁的数据偏多,而老年人和小孩的数据偏少,这会导致训练的模型会偏向于样本数较多的类别。建议进行样本均衡处理。
  10. 调超参: 比如学习率调整策略,优化器(SGD,Adam等)
  11. 损失函数: 目前训练代码已经支持:交叉熵,LabelSmoothing,可以尝试FocalLoss等损失函数

(2) 将Pytorch模型转换ONNX模型

目前CNN模型有多种部署方式,可以采用TNN,MNN,NCNN,以及TensorRT等部署工具,鄙人采用TNN进行C/C++端上部署。部署流程可分为四步:训练模型->将模型转换ONNX模型->将ONNX模型转换为TNN模型->C/C++部署TNN模型。

训练好Pytorch模型后,我们需要先将模型转换为ONNX模型,以便后续模型部署。

  • 原始项目提供转换脚本,你只需要修改model_file为你模型路径即可
  •  convert_torch_to_onnx.py实现将Pytorch模型转换ONNX模型的脚本
python libs/convert/convert_torch_to_onnx.py
"""
This code is used to convert the pytorch model into an onnx format model.
"""
import sys
import os
 
sys.path.insert(0, os.getcwd())
import argparse
from demo import Predictor
from basetrainer.utils import log, setup_config
from basetrainer.utils.converter import pytorch2onnx
 
 
def get_parser():
    # 配置文件
    config_file = "../../configs/config.yaml"
    # 模型文件
    # model_file = "../../work_space/AE_mobilenet_v2_1.0_L1Loss_20240103_191147_2420/model/best_model_073_0.8107.pth"
    model_file = "../../work_space/AE_mobilenet_v2_1.0_L1Loss_20240105_181151_9813/model/best_model_082_0.8021.pth"
    parser = argparse.ArgumentParser(description="Inference Argument")
    parser.add_argument("-c", "--config_file", help="configs file", default=config_file, type=str)
    parser.add_argument("-m", "--model_file", help="model_file", default=model_file, type=str)
    parser.add_argument("--use_age_ld", help="use age label distribution", default=1, type=int)
    parser.add_argument("--device", help="cuda device id", default="cuda:0", type=str)
    return parser
 
 
def convert_torch_to_onnx(cfg):
    p = Predictor(cfg=cfg)
    model = p.model
    w, h = cfg.input_size
    input_shape = (1, 3, h, w)
    onnx_file = str(cfg.model_file).replace(".pth", ".onnx")
    pytorch2onnx.convert2onnx(model,
                              input_shape=input_shape,
                              input_names=['input'],
                              output_names=['age', 'gender'],
                              onnx_file=onnx_file,
                              opset_version=9)
 
 
if __name__ == "__main__":
    parser = get_parser()
    print(parser.parse_args())
    cfg = setup_config.parser_config(parser.parse_args(), cfg_updata=False)
    convert_torch_to_onnx(cfg)

(3) 将ONNX模型转换为TNN模型

目前CNN模型有多种部署方式,可以采用TNN,MNN,NCNN,以及TensorRT等部署工具,鄙人采用TNN进行C/C++端上部署

TNN转换工具:

baa5c3a1527340d89194ec142af6aba0.png​​​​

 模型转换成功后,会得到两个TNN模型,一个描述模型结构的*.tnnproto的文件,一个是模型参数*.tnnmodel文件;下载到本地,后面C++/Android端上部署,需要加载*.tnnproto和*.tnnmodel文件进行模型推理。


4.年龄性别预测和识别模型Android部署

 (1) Android开发版本

开发前,请对齐Android Studio 开发版本,避免版本差异导致编译的异常; Android SDK和NDK相关版本信息,请参考: 

  • Android Studio 4.1.1
  • JDK Java1.8(Jave 8) JAVA_VERSION="1.8.0_242"
  • CMake 3.18.1

0649905fec8a43378987ba964fd62261.png

5bb2cd26f9a447c28ddd0185c19bed07.png 1f7f7098aca54d4b862abb2a9bb4e07f.png

(2) Android端上部署模型

项目实现了Android版本的年龄性别预测和识别Demo,部署框架采用TNN,支持多线程CPU和GPU加速推理,在普通手机上可以实时处理。项目Android源码,核心算法均采用C++实现,上层通过JNI接口调用.

如果你想在这个Android Demo部署你自己训练的分类模型,你可将训练好的Pytorch模型转换ONNX ,再转换成TNN模型,然后把TNN模型代替你模型即可。

  • 这是项目Android源码JNI接口 ,Java部分
package com.cv.tnn.model;

import android.graphics.Bitmap;

public class Detector {

    static {
        System.loadLibrary("tnn_wrapper");
    }


    /***
     * 初始化检测模型
     * @param det_model: 检测模型(不含后缀名)
     * @param cls_model: 识别模型(不含后缀名)
     * @param root:模型文件的根目录,放在assets文件夹下
     * @param model_type:模型类型
     * @param num_thread:开启线程数
     * @param useGPU:是否开启GPU进行加速
     */
    public static native void init(String det_model, String cls_model, String root, int model_type, int num_thread, boolean useGPU);

    /***
     * 返回检测和识别结果
     * @param bitmap 图像(bitmap),ARGB_8888格式
     * @param score_thresh:置信度阈值
     * @param iou_thresh:  IOU阈值
     * @return
     */
    public static native FrameInfo[] detect(Bitmap bitmap, float score_thresh, float iou_thresh);
}

  • 这是Android项目源码JNI接口 ,C++部分
#include <jni.h>
#include <string>
#include <fstream>
#include "src/object_detection.h"
#include "src/classification.h"
#include "src/Types.h"
#include "debug.h"
#include "android_utils.h"
#include "opencv2/opencv.hpp"
#include "file_utils.h"

using namespace dl;
using namespace vision;

static ObjectDetection *detector = nullptr;
static Classification *classifier = nullptr;

JNIEXPORT jint JNI_OnLoad(JavaVM *vm, void *reserved) {
    return JNI_VERSION_1_6;
}

JNIEXPORT void JNI_OnUnload(JavaVM *vm, void *reserved) {

}


extern "C"
JNIEXPORT void JNICALL
Java_com_cv_tnn_model_Detector_init(JNIEnv *env,
                                    jclass clazz,
                                    jstring det_model,
                                    jstring cls_model,
                                    jstring root,
                                    jint model_type,
                                    jint num_thread,
                                    jboolean use_gpu) {
    if (detector != nullptr) {
        delete detector;
        detector = nullptr;
    }
    std::string parent = env->GetStringUTFChars(root, 0);
    std::string det_model_ = env->GetStringUTFChars(det_model, 0);
    std::string cls_model_ = env->GetStringUTFChars(cls_model, 0);
    string det_model_file = path_joint(parent, det_model_ + ".tnnmodel");
    string det_proto_file = path_joint(parent, det_model_ + ".tnnproto");
    string cls_model_file = path_joint(parent, cls_model_ + ".tnnmodel");
    string cls_proto_file = path_joint(parent, cls_model_ + ".tnnproto");
    DeviceType device = use_gpu ? GPU : CPU;
    LOGW("parent     : %s", parent.c_str());
    LOGW("useGPU     : %d", use_gpu);
    LOGW("device_type: %d", device);
    LOGW("model_type : %d", model_type);
    LOGW("num_thread : %d", num_thread);
    ObjectDetectionParam model_param = FACE_MODEL;
    detector = new ObjectDetection(det_model_file,
                                   det_proto_file,
                                   model_param,
                                   num_thread,
                                   device);

    //ClassificationParam ClassParam = FACE_MASK_MODEL;
    ClassificationParam ClassParam = EYEGLASSES_MODEL;
    classifier = new Classification(cls_model_file,
                                    cls_proto_file,
                                    ClassParam,
                                    num_thread,
                                    device);
}

extern "C"
JNIEXPORT jobjectArray JNICALL
Java_com_cv_tnn_model_Detector_detect(JNIEnv *env, jclass clazz, jobject bitmap,
                                      jfloat score_thresh, jfloat iou_thresh) {
    cv::Mat bgr;
    BitmapToMatrix(env, bitmap, bgr);
    int src_h = bgr.rows;
    int src_w = bgr.cols;
    // 检测区域为整张图片的大小
    FrameInfo resultInfo;
    // 开始检测
    if (detector != nullptr) {
        detector->detect(bgr, &resultInfo, score_thresh, iou_thresh);
    } else {
        ObjectInfo objectInfo;
        objectInfo.x1 = 0;
        objectInfo.y1 = 0;
        objectInfo.x2 = (float)src_w;
        objectInfo.y2 = (float)src_h;
        objectInfo.label = 0;
        resultInfo.info.push_back(objectInfo);
    }

    int nums = resultInfo.info.size();
    LOGW("object nums: %d\n", nums);
    if (nums > 0) {
        // 开始检测
        classifier->detect(bgr, &resultInfo);
        // 可视化代码
        printf("sitting label:%d,score:%3.5f", resultInfo.label, resultInfo.score);
        //classifier->visualizeResult(bgr, &resultInfo);
    }
    //cv::cvtColor(bgr, bgr, cv::COLOR_BGR2RGB);
    //MatrixToBitmap(env, bgr, dst_bitmap);
    auto BoxInfo = env->FindClass("com/cv/tnn/model/FrameInfo");
    auto init_id = env->GetMethodID(BoxInfo, "<init>", "()V");
    auto box_id = env->GetMethodID(BoxInfo, "addBox", "(FFFFIF)V");
    auto ky_id = env->GetMethodID(BoxInfo, "addKeyPoint", "(FFF)V");
    jobjectArray ret = env->NewObjectArray(resultInfo.info.size(), BoxInfo, nullptr);
    for (int i = 0; i < nums; ++i) {
        auto info = resultInfo.info[i];
        env->PushLocalFrame(1);
        //jobject obj = env->AllocObject(BoxInfo);
        jobject obj = env->NewObject(BoxInfo, init_id);
        // set bbox
        //LOGW("rect:[%f,%f,%f,%f] label:%d,score:%f \n", info.rect.x,info.rect.y, info.rect.w, info.rect.h, 0, 1.0f);
        env->CallVoidMethod(obj, box_id, info.x1, info.y1, info.x2 - info.x1, info.y2 - info.y1,
                            info.category.label, info.category.score);
        // set keypoint
        for (const auto &kps : info.landmarks) {
            //LOGW("point:[%f,%f] score:%f \n", lm.point.x, lm.point.y, lm.score);
            env->CallVoidMethod(obj, ky_id, (float) kps.x, (float) kps.y, 1.0f);
        }
        obj = env->PopLocalFrame(obj);
        env->SetObjectArrayElement(ret, i, obj);
    }
    return ret;
}

(3) Android测试效果 

Android Demo在普通手机CPU/GPU上可以达到实时检测和识别效果;CPU(4线程)约30ms左右,GPU约25ms左右 ,基本满足业务的性能需求。

5c1c555f77564d0fb55f734876e149f1.gif  ed82b106efce4dbb8e316a1ceb310d90.gif

(4) 运行APP闪退:dlopen failed: library "libomp.so" not found

参考解决方法:
解决dlopen failed: library “libomp.so“ not found


5.项目源码下载

Android项目源码下载地址:年龄性别预测3:Android实现年龄性别预测和识别(含源码,可实时预测)

整套Android项目源码内容包含:

  1. 提供Android版本的人脸检测模型
  2. 提供年龄性别预测和识别Android Demo项目源码,源码可用于二次开发
  3. Android Demo在普通手机CPU/GPU上可以实时检测,CPU约30ms,GPU约25ms左右
  4. Android Demo支持图片,视频,摄像头测试
  5. 所有依赖库都已经配置好,可直接build运行,若运行出现闪退,请参考dlopen failed: library “libomp.so“ not found 解决。

年龄性别预测和识别Android APP Demo体验:https://download.csdn.net/download/guyuealian/88743711

如果你需要年龄性别预测和识别的训练代码,请参考 《Pytorch实现年龄性别预测和识别(含训练代码和数据)

 

  • 23
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI吃大瓜

尊重原创,感谢支持

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值