- 在决定使用MNN实现在线训练之前,也比较了TNN/NCNN,发现目前各大端侧推理引擎的训练框架都不成熟,半斤八两的状态,可能都把精力放在推理和op支持上,但是端侧训练的需求真的少么?fine-tune在端侧应用难道不是刚需?
- 端侧推理的实现相对简单,MNN官方已有完善的文档参考。
- 这篇主要介绍基于MNN的深度学习端侧在线训练,以预编译动态库的方式引用MNN训练框架,并实现在Android端基于预训练模型的在线finetune全流程。
MNN在线训练思路梳理
-
在线训练的需求场景(why)
- 基于预训练模型迁移学习fine-tune(本篇示例场景)
- 使用MNN进行训练量化
- 端侧在线学习(如移动端,嵌入式设备)
-
使用MNN进行在线训练的两种方式(how)
- 使用MNNConverter将模型转换为可训练模型,端侧加载模型进行训练(本篇示例方式)
- 端侧利用MNN API从零搭建模型
- 参考https://www.yuque.com/mnn/cn/kgd9hd
-
基于MNN实现端侧Finetune所需步骤
预训练模型转换成可训练的模型
- 不同深度学习框架输出的预训练模型需要先转换为mnn格式
- 参考我的另一篇https://blog.csdn.net/hechao3225/article/details/114820905
- 将mnn格式的预训练模型转换为可训练模型
- 开启–forTraining保留BatchNorm,Dropout等训练过程中会用到的算子
- 如果你的模型中没有BN,Dropout等在转MNN推理模型时会被融合掉的算子,那么直接使用MNN推理模型也可以进行训练,不必重新进行转换
./MNNConvert --modelFile mobilenet_v2_1.0_224_frozen.pb --MNNModel mobilenet_v2_tfpb_train.mnn --framework TF --bizCode AliNNTest --forTraining
Android平台编译库
-
在线训练相关编译开关设置
- 根目录下CMakeList.txt编译开关
- 打开MNN_BUILD_TRAIN=ON
- tools/train下CMakeList.txt编译开关
- MNN_BUILD_TRAIN_MINI=ON : 对于移动端/嵌入式设备,建议设置 MNN_BUILD_TRAIN_MINI = ON,不编译内置的Dataset,Models
- MNN_USE_OPENCV=OFF : 部分 PC 上的 demo 有用到,暂时用不到
- 根目录下CMakeList.txt编译开关
-
具体步骤
-
配置NDK环境变量,指定NDK版本:在 .bashrc 或者 .bash_profile 中设置 NDK 环境变量,
export ANDROID_NDK=/home/goodix/code/fp_prebuilts/tools/android-ndk-r17b
-
mnn源码根目录执行
./schema/generate.sh
-
进入android目录
cd project/android
-
编译armv7动态库
mkdir build_32 && cd build_32 && ../build_32.sh
-
编译armv8动态库:
mkdir build_64 && cd build_64 && ../build_64.sh
-
-
Android平台在线训练依赖的三个库,如果仅作在线推理,只需要引用libMNN.so,如果需要在线训练,还需要libMNN_Express.so和libMNNTrain.so。
路径如下:
- project/android/build_64/libMNN.so
- project/android/build_64/libMNN_Express.so
- project/android/build_64/tools/train/libMNNTrain.so
-
官方文档参考:https://www.yuque.com/mnn/cn/build_android
在线迁移学习Finetune Demo实现
-
实现微调模型MyAnnTransferModule
-
对于finetune场景,我们不需要端侧从零搭建模型,只需要加载预训练模型,固定神经网络前面层的参数,仅对全连接层最后一层用于微调
-
需要继承Module类,并重写构造函数和onForward
class MyAnnTransferModule : public MNN::Express::Module { public: AlsAnnTransferModule(const char* fileName); virtual std::vector<MNN::Express::VARP> onForward(const std::vector<MNN::Express::VARP>& inputs) override; std::shared_ptr<Module> mFixedLayers; std::shared_ptr<Module> mFineTuneLayers; // add new layers for finetuning }; class MyAnnTransfer{ public: static void train(std::shared_ptr<MNN::Express::Module> model, std::vector<OneAlsData> trainData, std::vector<OneAlsData> testData); };
-
**[关键]**如何固定神经网络部分参数,仅对最后一层或最后几层微调?
-
通过netron模型可视化工具(或MNNConvert工具输出的模型json文件)查看最后一层的input.name
-
netron.app在线地址:https://netron.app/
-
示例,open module后点击最后一个Convolution,右边Input name即为最后一层的输入,我们需要通过整个name对模型分界
-
-
模型构造函数中loadMap加载模型,并使用input.name作为PipelineModule::extract的第二个参数,extract会保留除去最后一层的预训练模型
-
extract保留的部分为固定参数层,新初始化一个layer作为finetune层,仅注册finetune层用于训练,构造函数示例代码:
MyAnnTransferModule::MyAnnTransferModule(const char* fileName) { auto srcModelMap = Variable::loadMap(fileName); auto inputOutputs = Variable::getInputAndOutput(srcModelMap); auto input = inputOutputs.first.begin()->second; auto fixedOut = srcModelMap["Reshape45"]; // init a dense layer for finetuning // mFineTuneLayers.reset(NN::Linear(10, 5)); // use a conv layer for a dense layer NN::ConvOption option; option.channel = {10, 5}; mFineTuneLayers = std::shared_ptr<Module>(NN::Conv(option)); // get fixed layers from src module, set trainFlag=false mFixedLayers.reset(PipelineModule::extract({input}, {fixedOut}, false)); // only train finetuning layers registerModel({mFineTuneLayers}); }
-
然后重写onForward,固定层和微调层先后前向计算,需要注意微调层前向计算后需要进行_Convert和_Reshape操作,Reshape的维度信息仍然可以通过netron可视化工具查看,onForward函数示例代码:
std::vector<VARP> MyAnnTransferModule::onForward(const std::vector<VARP>& inputs) { auto fixedResult = mFixedLayers->forward(inputs[0]); auto result = _Reshape(_Convert(mFineTuneLayers->forward(fixedResult), NCHW), {-1, 5}); return {result}; }
-
-
-
实现train和test需要的dataset,由dataset创建dataloader
-
需要继承Dataset类,重写get()和size()两个虚函数,dataset与项目自己的数据格式有关,可参考官方MnistDataset
-
train函数中创建用于train和test的dataloader
auto dataset = MnistDataset::create(trainData, AlsDataset::Mode::TRAIN); const size_t batchSize = 1; const size_t numWorkers = 0; bool shuffle = true; auto dataLoader = std::shared_ptr<DataLoader>(dataset.createLoader(batchSize, true, shuffle, numWorkers)); size_t iterations = dataLoader->iterNumber(); auto testDataset = MnistDataset::create(testData, AlsDataset::Mode::TEST); const size_t testBatchSize = 1; const size_t testNumWorkers = 0; shuffle = false;
-
实现train过程(参考MnistUtils.cpp)
- 示例代码使用sgd优化器,学习率衰减等训练策略,可以根据自己需求使用MNN不同 API测试效果
- #ifdef DEBUG_GRAD宏包含了梯度校验的代码
for (int epoch = 0; epoch < 10; ++epoch) { model->clearCache(); exe->gc(Executor::FULL); exe->resetProfile(); { AUTOTIME; dataLoader->reset(); model->setIsTraining(true); Timer _100Time; int lastIndex = 0; int moveBatchSize = 0; for (int i = 0; i < iterations; i++) { auto trainData = dataLoader->next(); auto example = trainData[0]; moveBatchSize += example.first[0]->getInfo()->dim[0]; auto predict = model->forward(example.first[0]); auto loss = _MSE(predict, example.second[0]); //#define DEBUG_GRAD #ifdef DEBUG_GRAD { static bool init = false; if (!init) { init = true; std::set<VARP> para; example.first[0].fix(VARP::INPUT); newTarget.fix(VARP::CONSTANT); auto total = model->parameters(); for (auto p :total) { para.insert(p); } auto grad = OpGrad::grad(loss, para); total.clear(); for (auto iter : grad) { total.emplace_back(iter.second); } Variable::save(total, ".temp.grad"); } } #endif float rate = LrScheduler::inv(0.01, epoch * iterations + i, 0.0001, 0.75); sgd->setLearningRate(rate); if (moveBatchSize % (10 * batchSize) == 0 || i == iterations - 1) { std::cout << "epoch= " << (epoch) << std::endl; std::cout << moveBatchSize << " / " << dataLoader->size() << std::endl; std::cout << " lr= " << rate; std::cout << " time= " << (float)_100Time.durationInUs() / 1000.0f << " ms / " << (i - lastIndex) << " iter" << std::endl; std::cout.flush(); _100Time.reset(); lastIndex = i; } sgd->step(loss); } } }
-
保存训练模型
-
1和2的区别在于前者只保存参数,后者保存模型参数和结构
// 1. only save model parameters Variable::save(model->parameters(), "alsann.snapshot.mnn"); // 2. save model parameters and structure { model->setIsTraining(false); auto forwardInput = _Input({1, 1, 28, 28}, NC4HW4); forwardInput->setName("data"); auto predict = model->forward(forwardInput); predict->setName("prob"); Transformer::turnModelToInfer()->onExecute({predict}); Variable::save({predict}, "alsann.mnn"); }
-
如果需要train前加载模型参数:
// Load snapshot auto para = Variable::load("alsann.snapshot.mnn"); model->loadParameters(para);
-
-
注:MNN里已经实现了几个image相关的Dataset示例,如ImageDataset,MnistDataset
-
Finetue Demo编译配置
-
在main.cpp中创建模型和调用train函数,使用NDK编译成一个独立的二进制程序作为测试demo
- trainData和testData可以通过参数指定路径的方式传入
std::shared_ptr<Module> model(new MyAnnTransferModule(argv[1])); MyAnnTransferModule::train(model, trainData, testData);
-
我们将MNN编译的三个so以预编译共享库的方式加入Android.mk
-
定义三个库,MNN,MNN_express和MNN_train
LOCAL_PATH := $(call my-dir) include $(CLEAR_VARS) LOCAL_MODULE := MNN LOCAL_SRC_FILES := ${LOCAL_PATH}/../lib/lib64/libMNN.so include $(PREBUILT_SHARED_LIBRARY) include $(CLEAR_VARS) LOCAL_MODULE := MNN_express LOCAL_SRC_FILES := ${LOCAL_PATH}/../lib/lib64/libMNN_Express.so include $(PREBUILT_SHARED_LIBRARY) include $(CLEAR_VARS) LOCAL_MODULE := MNN_train LOCAL_SRC_FILES := ${LOCAL_PATH}/../lib/lib64/libMNNTrain.so include $(PREBUILT_SHARED_LIBRARY)
-
引用三个共享库
LOCAL_SHARED_LIBRARIES := MNN MNN_express MNN_train
-
-
Android.mk添加新增的头文件和源文件
LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/core LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/expr LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/plugin LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/express LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/express/module LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/mnn_train/data LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/mnn_train/grad LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/mnn_train/optimizer LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/mnn_train/parameters LOCAL_C_INCLUDES += ${LOCAL_PATH}/../include/MNN/mnn_train/transformer CPP_LIST := $(wildcard $(LOCAL_PATH)/*.cpp) LOCAL_SRC_FILES := $(CPP_LIST:$(LOCAL_PATH)/%=%)
-
指定NDK版本编译
/home/goodix/code/fp_prebuilts/tools/android-ndk-r17b/ndk-build NDK_APPLICATION_MK=Application.mk APP_BUILD_SCRIPT=Android.mk APP_ABI=arm64-v8a NDK_PROJECT_PATH=./ -B
-
编译完成后将模型和demo二进制push到手机测试,push脚本示例
set -e adb root adb remount adb shell setenforce 0 # adb shell rm -rf /usr/bin/arm64-v8a adb push obj/local/arm64-v8a/ /usr/bin/ adb push /home/goodix/code/als/Hamilton_DL/test_code/mnn_demo/model/my_ann.mnn /usr/bin/arm64-v8a/my_ann.mnn adb shell chmod +x /usr/bin/arm64-v8a/ALS_ANN adb shell export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/bin/arm64-v8a:/vendor/lib64:/system/lib64:/system/lib64/vndk-29:/system/lib
遇到的问题
-
问题1:编译link阶段找不到智能指针相关的定义
undefined reference to `std::__1::__shared_weak_count::__release_weak()'
-
Application.mk已指定C++11以上版本
APP_CPPFLAGS := -frtti -fexceptions -std=c++14
-
原因是STL的版本必须指定为gnustl_shared,不能是成c++_static,c++_shared, stlport_shared等,这些版本本身不支持shared_ptr和function相关特性。
在gnustl版本中,shared_ptr定义在NDK根目录\sources\cxx-stl\gnu-libstdc++\4.8\include\memory文件中。
Application.mk修改
APP_STL := gnustl_shared
-
但是,ndk-r19c等高版本的STL版本不支持gnustl_shared,还需要切换ndk-17b版本编译
-
-
问题2:libMNN_Express.so链接不到函数定义
- 解决方法:对齐mnn编译脚本和demo编译脚本的stl版本,ndk版本