MMoE模型keras源码
https://github.com/drawbridge/keras-mmoe
所用census-income数据集下载地址
https://github.com/drawbridge/keras-mmoe
1.数据处理:
常用的数据集可以直接调用mindspore.dataset接口实现,非常的方便。其使用方法可在ms官网编程指南中查看。其他的数据集可使用mindrecord接口,生成mindrecord格式数据,读写非常高效,具有很好的性能,但是生成的mindrecord格式的文件要比源文件大很多。因此是非常大的数据集时,也可以使用mindspore.dataset.GeneratorDataset进行自定义加载。
2.网络迁移:
mindspore在构建网络时,在init()函数定义所需的算子、层等。在construct()函数中搭建网络的前向结构。其中,在迁移的过程中,涉及到算子的映射,以下链接是pytorch和tensorflow与mindspore算子的映射:
API映射 — MindSpore master documentation
3.损失函数:
本次迁移模型时,网络涉及到多输出、数据涉及到多标签因此在自定义损失函数时,将网络和损失参照以下链接多损失的使用方法进行自定义:
损失函数 — MindSpore master documentation
4.训练:
训练主要分为:定义网络,加载处理好的数据集,定义损失函数,定义优化器,正常情况下是可以直接调用model.train将网络与损失函数、优化器封装起来进行训练的。但在本次训练的过程中因为损失函数的数据不是data,label。因此使用自定义的WithLossCell将网络和损失函数结合使用,调用TrainOneStepCell将优化器和网络、损失函数结合起来,最终将它们传入model.train进行训练。
5.CallBack:
一般官方定义好的Callback可以很好地将训练的loss及per step time返回。
6.评估:
在本次模型迁移时,主要使用AUC作为评估指标。因为是多loss的。因此不能根据model.eval进行验证。使用了以下链接中网络的循环实现:
official/cv/densenet/eval.py · MindSpore/models - Gitee.com
7.modelarts训练:
在云上训练和本地最大的差别就是需要调用moxing接口将数据集等文件从obs上传到catch容器中。在完成训练之后将得到的ckpt文件使用moxing将它从catch中拷贝到obs桶里边。适配问题可以参考model_zoo里AlxeNet的适配: