多任务模型工具包LibMTL使用记录

本文作者分享了如何在遇到困惑时安装和配置LibMTL库用于多任务学习模型,包括环境设置、包版本选择、参数初始化以及自定义模型的应用,着重提到了MMoE模型和参数调整的重要性。
摘要由CSDN通过智能技术生成

因为一些原因要写一个简单的多任务学习模型作为对比,这种情况下直接掉包比较方便,然而在网上找的时候感觉很混乱,后来师姐推荐了LibMTL库,据说很好用(得到过学弟的推荐)

然而不会用,连手册都用不好,摸索了几天总算跑通了,写一篇文章记录一下(防止忘了)

本文写作时间是2023年12月28日,如果后续更新本文应该也不会更新

首先是安装,LibMTL库的手册及github上提供了安装库的方法,然而有坑:

这是手册上对于安装环境的要求:

Python >= 3.7
torch >= 1.8.0
torchvision >= 0.9.0
numpy >= 1.20

请不要直接写,因为pytorch后续更新了,pytorch的后续更新改变了一个函数,该函数在LibMTL库中调用过,所以如果只遵守该环境(torch >= 1.8.0),那么一定无法运行
推荐新建一个环境然后按照手册上的标准安装

conda create -n libmtl python=3.8
conda activate libmtl
# pip install torch==1.8.0 torchvision==0.9.0 numpy==1.20

第三步安装pytorch的操作也要谨慎选择,因为国内直接这样安装并不能安装满足要求的pytorch,建议去pytorch官网 -Get Start - Previous Pytorch Versions
然后找到适配你的电脑的torch==1.8.0或者1.8.1(其他版本我没试,因为我不知道Pytorch废除那个函数到底是哪个版本的事情)
没记错的话,我用的应该是:

conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge

完成环境配置后,进入环境,clone项目

git clone https://github.com/median-research-group/LibMTL.git
cd LibMTL
pip install -e .

我中途有几个包安装失败了,于是单独安装了一下

关于LibMTL的使用,官方提供了一个例子

这个例子已经很详细了,但是应用起来还是有些问题要解决。
首先是kwargs, optim_param, scheduler_param = prepare_args(params)
这一步不要省略,我原先想的是,既然要自己写,那么肯定参数由自己定,后来才发现,这个库要用到的参数太多了,用这一行代码可以取代掉绝大多数没啥用的参数,否则一个个输入这些参数会特别麻烦(其中很多参数是在其他任务中使用的),使用这一行代码最重要的一点就是,如果后面需要使用某些参数,你可以再输入覆盖掉,所以这是很重要的初始化,不能省略。
如果想知道这些参数是什么可以打印出来

例子中有这样一个代码:

def parse_args(parser):
    parser.add_argument('--aug', action='store_true', default=False, help='data augmentation')
    parser.add_argument('--train_mode', default='trainval', type=str, help='trainval, train')
    parser.add_argument('--train_bs', default=8, type=int, help='batch size for training')
    parser.add_argument('--test_bs', default=8, type=int, help='batch size for test')
    parser.add_argument('--epochs', default=200, type=int, help='training epochs')
    parser.add_argument('--dataset_path', default='/', type=str, help='dataset path')
    return parser.parse_args()

这段代码定义了一个插入参数的方法,是一个很好的示例,因为前面代码给的参数都是对应包里面需要使用到的功能,对于一些我们需要用到但是又与这个工具包没关系的参数,就可以用这种方式来获取,比如例子中的dataset都是根据自定义的参数来设计的。
然后是task_dict字典,表示你的任务,比如我的任务是分类和分割,我就只给了两项,一个segmentation,一个classification。字典的key并没有固定套路,只是一个名字,例子中用的是segmentation,depth和normal。字典对应的键值都是四项,这个比较固定,可以查看一下,这里的SegMetric()之类的都是有对应的函数的,根据自己的需求来选择,其中,weight和metrics是对应的,表示对于该结果的加权:

# define tasks
task_dict = {'segmentation': {'metrics':['mIoU', 'pixAcc'], 
                          'metrics_fn': SegMetric(),
                          'loss_fn': SegLoss(),
                          'weight': [1, 1]}, 
             'depth': {'metrics':['abs_err', 'rel_err'], 
                       'metrics_fn': DepthMetric(),
                       'loss_fn': DepthLoss(),
                       'weight': [0, 0]},
             'normal': {'metrics':['mean', 'median', '<11.25', '<22.5', '<30'], 
                        'metrics_fn': NormalMetric(),
                        'loss_fn': NormalLoss(),
                        'weight': [0, 0, 1, 1, 1]}}

我使用的是:

task_dict = {'segmentation': {'metrics': ['mIoU', 'pixAcc'],
                               'metrics_fn': SegMetric(num_classes=10),
                               'loss_fn': MultiClassDiceLoss(num_classes=10),
                               'weight': [1, 1]},
              'classification': {'metrics': ['Acc'],
                                 'metrics_fn': AccMetric(),
                                 'loss_fn': CELoss(),
                                 'weight': [1]}
              }

我模仿了原有的loss自己写了一个,继承自工具包中的Absloss,可以根据自己的需求自行定制修改这里。

然后是Trainer部分,大部分我没改,主要是process_preds方法需要修改。例子里面给的是一个上采样,我的是这样写的:

def process_preds(self, preds, img_size=(256, 256)):
    if 'classification' in preds:
        # 假设 preds['classification'] 是 [batch_size, 1, 32, 32] 形状
        preds_classification = torch.mean(preds['classification'], dim=[2, 3])  # 全局平均池化
        preds_classification = preds_classification.unsqueeze(1)  # [batch_size, 1]
        preds_opposite = 1 - preds_classification  # 另一类别的得分
        preds['classification'] = torch.cat([preds_opposite, preds_classification], dim=1).squeeze(2)

    # 处理分割任务
    if 'segmentation' in preds:
        # 上采样到原始图像尺寸
        preds['segmentation'] = F.interpolate(preds['segmentation'], img_size, mode='bilinear',
                                              align_corners=True)
    # print(preds['classification'].shape)
    # print(preds['segmentation'].shape)

    return preds

可以看出来,输出其实是preds,是一个字典,原先有多少个输出,字典就有多少个key,key的值就是dict里面设定的名称,所以例子中的key应该是segmentation,depth和normal。

后面运行没多少好说的,就注意一个:
arch这个表示你要用的模型,我要用MMoE(这个在github页面找,有名字的)所以在输入参数的时候就设定了:

--arch MMoE

例子里面用的是默认的arch,也就是默认的模型。
具体使用哪个模型可以根据自己的要求来挑选,注意,不同的模型需要的参数也不完全相同,如果用的不是默认的模型,那么就要添加对应的参数,替换掉prepare_args自动准备的参数。

  • 20
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值