github上有两个版本的多任务训练分别是:
1、https://github.com/miraclewkf/multi-task-MXNet
2、mxnet自带的例子
第一个由于其数据迭代器是Image,可能会比较慢。
第二个的例子是mnist,需要自己修改数据迭代器。
这里主要记录基于ImageRecordIter迭代器的多任务训练。
1、数据制作
需要自己生成*.lst文件,里面内容如下:
index task1标签 task2标签 task3标签 图片路径(这行是说明,不需要写入,每一列用\t隔开)
2476 0.000000 0.000000 1.000000 photo_02_8159/00022552.jpg
7623 3.000000 2.000000 2.000000 photo_03_7397/00029434.jpg
14149 0.000000 0.000000 1.000000 photo_05_15560/00060839.jpg
6874 3.000000 1.000000 2.000000 photo_03_7397/00028414.jpg
6048 0.000000 0.000000 1.000000 photo_02_8159/00027259.jpg
14479 3.000000 3.000000 2.000000 photo_05_15560/00065068.jpg
10429 2.000000 0.000000 1.000000 photo_04_15224/00040186.jpg
6949 3.000000 0.000000 1.000000 photo_03_7397/00028521.jpg
81 3.000000 3.000000 2.000000 photo_01_19992/00002536.jpg
11725 2.000000 0.000000 1.000000 photo_05_15560/00051778.jpg
1517 2.000000 3.000000 2.000000 photo_02_8159/00021245.jpg
具体是生成方法可以参考mxnet提供的im2rec.py,可以自己写一个make_list函数。
生成*.rec文件。这个文件可以用im2rec.py生成,同时需要把pack-label设置为True。
2、修改模型结构
添加3个mx.symbol.SoftmaxOutput损失函数(因为我这边是3个任务):
fc1 = mx.symbol.FullyConnected(data=flat, num_hidden=5, name='fc1') #任务1 有5个类别
fc2 = mx.symbol.FullyConnected(data=flat, num_hidden=15, name='fc2') #任务2 有15个类别
fc3 = mx.symbol.FullyConnected(data=flat, num_hidden=3, name='fc3') #任务3 有3个类别
#分别为这三个任务添加softmax损失函数,注意每个函数的名称,后面会用到
s1 =