深度学习【22】Mxnet多任务(multi-task)训练

github上有两个版本的多任务训练分别是:
1、https://github.com/miraclewkf/multi-task-MXNet
2、mxnet自带的例子
第一个由于其数据迭代器是Image,可能会比较慢。
第二个的例子是mnist,需要自己修改数据迭代器。
这里主要记录基于ImageRecordIter迭代器的多任务训练。

1、数据制作
需要自己生成*.lst文件,里面内容如下:


index   task1标签   task2标签    task3标签    图片路径(这行是说明,不需要写入,每一列用\t隔开)
2476    0.000000    0.000000    1.000000    photo_02_8159/00022552.jpg 
7623    3.000000    2.000000    2.000000    photo_03_7397/00029434.jpg
14149   0.000000    0.000000    1.000000    photo_05_15560/00060839.jpg
6874    3.000000    1.000000    2.000000    photo_03_7397/00028414.jpg
6048    0.000000    0.000000    1.000000    photo_02_8159/00027259.jpg
14479   3.000000    3.000000    2.000000    photo_05_15560/00065068.jpg
10429   2.000000    0.000000    1.000000    photo_04_15224/00040186.jpg
6949    3.000000    0.000000    1.000000    photo_03_7397/00028521.jpg
81      3.000000    3.000000    2.000000    photo_01_19992/00002536.jpg
11725   2.000000    0.000000    1.000000    photo_05_15560/00051778.jpg
1517    2.000000    3.000000    2.000000    photo_02_8159/00021245.jpg

具体是生成方法可以参考mxnet提供的im2rec.py,可以自己写一个make_list函数。
生成*.rec文件。这个文件可以用im2rec.py生成,同时需要把pack-label设置为True。
2、修改模型结构
添加3个mx.symbol.SoftmaxOutput损失函数(因为我这边是3个任务):

    fc1 = mx.symbol.FullyConnected(data=flat, num_hidden=5, name='fc1') #任务1 有5个类别
    fc2 = mx.symbol.FullyConnected(data=flat, num_hidden=15, name='fc2') #任务2 有15个类别
    fc3 = mx.symbol.FullyConnected(data=flat, num_hidden=3, name='fc3') #任务3 有3个类别
    #分别为这三个任务添加softmax损失函数,注意每个函数的名称,后面会用到
    s1 = 
  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值