【经验分享】mindspore模型迁移

MMoE模型keras源码

https://github.com/drawbridge/keras-mmoe

所用census-income数据集下载地址

https://github.com/drawbridge/keras-mmoe

1.数据处理:

常用的数据集可以直接调用mindspore.dataset接口实现,非常的方便。其使用方法可在ms官网编程指南中查看。其他的数据集可使用mindrecord接口,生成mindrecord格式数据,读写非常高效,具有很好的性能,但是生成的mindrecord格式的文件要比源文件大很多。因此是非常大的数据集时,也可以使用mindspore.dataset.GeneratorDataset进行自定义加载。

2.网络迁移:

mindspore在构建网络时,在init()函数定义所需的算子、层等。在construct()函数中搭建网络的前向结构。其中,在迁移的过程中,涉及到算子的映射,以下链接是pytorch和tensorflow与mindspore算子的映射:

API映射 — MindSpore master documentation

3.损失函数:

本次迁移模型时,网络涉及到多输出、数据涉及到多标签因此在自定义损失函数时,将网络和损失参照以下链接多损失的使用方法进行自定义:

损失函数 — MindSpore master documentation

4.训练:

训练主要分为:定义网络,加载处理好的数据集,定义损失函数,定义优化器,正常情况下是可以直接调用model.train将网络与损失函数、优化器封装起来进行训练的。但在本次训练的过程中因为损失函数的数据不是data,label。因此使用自定义的WithLossCell将网络和损失函数结合使用,调用TrainOneStepCell将优化器和网络、损失函数结合起来,最终将它们传入model.train进行训练。

5.CallBack:

一般官方定义好的Callback可以很好地将训练的loss及per step time返回。

6.评估:

在本次模型迁移时,主要使用AUC作为评估指标。因为是多loss的。因此不能根据model.eval进行验证。使用了以下链接中网络的循环实现:

official/cv/densenet/eval.py · MindSpore/models - Gitee.com

7.modelarts训练:

在云上训练和本地最大的差别就是需要调用moxing接口将数据集等文件从obs上传到catch容器中。在完成训练之后将得到的ckpt文件使用moxing将它从catch中拷贝到obs桶里边。适配问题可以参考model_zoo里AlxeNet的适配:

models: Models of MindSpore - Gitee.com

  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值