Applying LibMTL to a Custom Model

LibMTL is an open-source library for multi-task learning built on PyTorch.

Walkthrough of the official tutorial example

This walkthrough uses the office example. There is essentially only one kind of task (image classification, repeated over several domains), but the shared network is custom-defined, which makes this a convenient template for writing your own custom network.

import torch
import torch.nn as nn

from LibMTL import Trainer
from LibMTL.model import resnet18
from LibMTL.utils import set_device, set_random_seed
from LibMTL.config import LibMTL_args, prepare_args
from LibMTL.metrics import AccMetric
from LibMTL.loss import CELoss
import LibMTL.architecture as architecture_method
import LibMTL.weighting as weighting_method
# office_dataloader is provided by create_dataset.py in the official office example
from create_dataset import office_dataloader

# this function adds the command-line arguments you may want to modify
def parse_args(parser):
    parser.add_argument('--dataset', default='office-31', type=str, help='office-31, office-home')
    parser.add_argument('--bs', default=64, type=int, help='batch size')
    parser.add_argument('--epochs', default=100, type=int, help='training epochs')
    parser.add_argument('--dataset_path', default='/', type=str, help='dataset path')
    return parser.parse_args()

def main(params):
    #params.arch='HPS'        # MTL architecture, params.arch: ['HPS', 'Cross_stitch', 'MTAN', 'CGC', 'PLE', 'MMoE', 'DSelect_k', 'DIY']
    #params.weighting='EW'    # loss-weighting strategy, kwargs: ['EW', 'UW', 'GradNorm', 'GLS', 'RLW', 'MGDA', 'IMTL', 'PCGrad', 'GradVac', 'CAGrad', 'GradDrop', 'DWA', 'DIY']
    #params.optim='adam'      # optimizer, optim_param: ['adam', 'sgd', 'adagrad', 'rmsprop']
    #params.scheduler=None    # learning-rate scheduler, scheduler_param: ['step', 'cos', 'exp']
    kwargs, optim_param, scheduler_param = prepare_args(params)  # weighting/architecture kwargs, optimizer settings, LR-scheduler settings
    
    # select the dataset: task names and number of classes
    if params.dataset == 'office-31':
        task_name = ['amazon', 'dslr', 'webcam']
        class_num = 31
    elif params.dataset == 'office-home':
        task_name = ['Art', 'Clipart', 'Product', 'Real_World']
        class_num = 65
    else:
        raise ValueError('No support dataset {}'.format(params.dataset))
    
    # define tasks: one classification task per domain
    task_dict = {task: {'metrics': ['Acc'],          # the metric names of this task
                        'metrics_fn': AccMetric(),   # the class that computes the metrics
                        'loss_fn': CELoss(),         # the loss function
                        'weight': [1]}               # 1 means a higher metric value is better (0 means lower is better); used later to pick the best epoch
                 for task in task_name}
    
    # prepare dataloaders
    data_loader, _ = office_dataloader(dataset=params.dataset, batchsize=params.bs, root_path=params.dataset_path)
    train_dataloaders = {task: data_loader[task]['train'] for task in task_name}
    val_dataloaders = {task: data_loader[task]['val'] for task in task_name}
    test_dataloaders = {task: data_loader[task]['test'] for task in task_name}
    
    # custom shared network (encoder): just a regular PyTorch module
    class Encoder(nn.Module):
        def __init__(self):
            super(Encoder, self).__init__()
            hidden_dim = 512
            self.resnet_network = resnet18(pretrained=True)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.hidden_layer_list = [nn.Linear(512, hidden_dim),
                                      nn.BatchNorm1d(hidden_dim), nn.ReLU(), nn.Dropout(0.5)]
            self.hidden_layer = nn.Sequential(*self.hidden_layer_list)

            # initialization
            self.hidden_layer[0].weight.data.normal_(0, 0.005)
            self.hidden_layer[0].bias.data.fill_(0.1)
            
        def forward(self, inputs):
            out = self.resnet_network(inputs)
            out = torch.flatten(self.avgpool(out), 1)
            out = self.hidden_layer(out)
            return out
    # task-specific decoders: each task's own tower layer
    decoders = nn.ModuleDict({task: nn.Linear(512, class_num) for task in list(task_dict.keys())})
    # multi-task model: LibMTL's Trainer class is used directly; subclass it if the behaviour needs to be customized
    officeModel = Trainer(task_dict=task_dict, 
                          weighting=weighting_method.__dict__[params.weighting], 
                          architecture=architecture_method.__dict__[params.arch], 
                          encoder_class=Encoder, 
                          decoders=decoders,
                          rep_grad=params.rep_grad,
                          multi_input=params.multi_input,
                          optim_param=optim_param,
                          scheduler_param=scheduler_param,
                          save_path=params.save_path,
                          load_path=params.load_path,
                          **kwargs)
    
    if params.mode == 'train':  # training
        officeModel.train(train_dataloaders=train_dataloaders, 
                          val_dataloaders=val_dataloaders,
                          test_dataloaders=test_dataloaders, 
                          epochs=params.epochs)
    elif params.mode == 'test':  # testing
        officeModel.test(test_dataloaders)
    else:
        raise ValueError
    
if __name__ == "__main__":
    params = parse_args(LibMTL_args)
    # set device
    set_device(params.gpu_id)
    # set random seed
    set_random_seed(params.seed)
    main(params)
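
Assuming the script above is saved as train_office.py (the file name is just a placeholder), it can be launched with something like python train_office.py --multi_input --weighting EW --arch HPS --dataset office-31 --dataset_path /path/to/office-31 --gpu_id 0. The flags beyond those added in parse_args (such as --multi_input, --weighting, --arch and --gpu_id) are defined by LibMTL_args. The office example treats each domain as a separate task with its own dataloader, so --multi_input needs to be set.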

Detailed breakdown of task_dict

Here the NYUv2 example is used instead, because it has three different tasks.

task_dict = {'segmentation': {'metrics': ['mIoU', 'pixAcc'],   # task 1: metric names
                              'metrics_fn': SegMetric(),       # class that computes the metrics
                              'loss_fn': SegLoss(),            # loss function
                              'weight': [1, 1]},               # higher is better for both metrics
             'depth': {'metrics': ['abs_err', 'rel_err'],      # task 2: metric names
                       'metrics_fn': DepthMetric(),
                       'loss_fn': DepthLoss(),                 # loss function
                       'weight': [0, 0]},                      # lower is better for both metrics
             'normal': {'metrics': ['mean', 'median', '<11.25', '<22.5', '<30'],  # task 3: metric names
                        'metrics_fn': NormalMetric(),
                        'loss_fn': NormalLoss(),               # loss function
                        'weight': [0, 0, 1, 1, 1]}}            # lower is better for mean/median, higher is better for the threshold accuracies

AbsMetric()

The metric-computation classes are all based on AbsMetric(). We can subclass it to fit our own needs; the basic methods provided here are all required.

class AbsMetric(object):
    r"""An abstract class for the performance metrics of a task. 

    Attributes:
        record (list): A list of the metric scores in every iteration.
        bs (list): A list of the number of data in every iteration.
    """
    # initialize the recorded statistics
    def __init__(self):
        self.record = []
        self.bs = []
        
    # called after every batch; record whatever per-batch data you need here
    @property
    def update_fun(self, pred, gt):
        r"""Calculate the metric scores in every iteration and update :attr:`record`.

        Args:
            pred (torch.Tensor): The prediction tensor.
            gt (torch.Tensor): The ground-truth tensor.
        """
        pass
    
    # called once per epoch; computes the metric scores for that epoch
    @property
    def score_fun(self):
        r"""Calculate the final score (when a epoch ends).

        Return:
            list: A list of metric scores.
        """
        pass
    
    # called at the end of every epoch to reset the recorded statistics
    def reinit(self):
        r"""Reset :attr:`record` and :attr:`bs` (when a epoch ends).
        """
        self.record = []
        self.bs = []
    

Example of overriding a Metric

class AccMetric(AbsMetric):
    r"""Calculate the accuracy.
    """
    def __init__(self):
        super(AccMetric, self).__init__()
        
    def update_fun(self, pred, gt):  # record the number of correctly classified samples in each batch
        r"""
        """
        pred = F.softmax(pred, dim=-1).max(-1)[1]
        self.record.append(gt.eq(pred).sum().item())
        self.bs.append(pred.size()[0])
        
    def score_fun(self):
        r"""
        """
        return [(sum(self.record)/sum(self.bs))]  # accuracy over the whole epoch
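
As a further illustration, the sketch below defines a hypothetical mean-absolute-error metric for a regression task, following the same record/bs pattern. The class name and the metric itself are made up for this example; only the AbsMetric base class and the update_fun/score_fun interface come from the library.

import torch
from LibMTL.metrics import AbsMetric

class MAEMetric(AbsMetric):
    r"""A hypothetical mean-absolute-error metric (illustration only)."""
    def __init__(self):
        super(MAEMetric, self).__init__()

    def update_fun(self, pred, gt):  # record the summed absolute error of each batch
        self.record.append(torch.abs(pred - gt).sum().item())
        self.bs.append(pred.size()[0])

    def score_fun(self):
        return [sum(self.record)/sum(self.bs)]  # mean absolute error over the whole epoch

In a task_dict, such a metric would be listed as 'metrics': ['MAE'] with 'weight': [0], since a lower error is better.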

Trainer

process_preds

The library leaves this as a no-op (the predictions are returned unchanged), but if the predictions need further processing, you can override this function.

    def process_preds(self, preds, task_name=None):
        r'''The processing of prediction for each task. 

        - The default is no processing. If necessary, you can redefine this function. 
        - If ``multi_input``, ``task_name`` is valid, and ``preds`` with type :class:`torch.Tensor` is the prediction of this task.
        - otherwise, ``task_name`` is invalid, and ``preds`` is a :class:`dict` of name-prediction pairs of all tasks.

        Args:
            preds (dict or torch.Tensor): The prediction of ``task_name`` or all tasks.
            task_name (str): The string of task name.
        '''
        return preds
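
As an example of when this hook is useful: in a single-input setting (multi_input=False) where the decoders output dense predictions at a reduced resolution, a Trainer subclass can upsample them back to the label size before the loss and metrics are computed. The sketch below assumes exactly that; the subclass name, the target size, and the assumption that every prediction is a 4-D tensor are illustrative only.

import torch.nn.functional as F
from LibMTL import Trainer

class DenseTrainer(Trainer):  # hypothetical subclass, for illustration
    def process_preds(self, preds, task_name=None):
        img_size = (288, 384)  # assumed ground-truth resolution
        # multi_input is False here, so preds is a dict of name-prediction pairs
        for task in self.task_name:
            preds[task] = F.interpolate(preds[task], img_size,
                                        mode='bilinear', align_corners=True)
        return preds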

train

This is the training function; the details of the training loop can be modified here.

    def train(self, train_dataloaders, test_dataloaders, epochs, 
              val_dataloaders=None, return_weight=False):
        r'''The training process of multi-task learning.

        Args:
            train_dataloaders (dict or torch.utils.data.DataLoader): The dataloaders used for training. \
                            If ``multi_input`` is ``True``, it is a dictionary of name-dataloader pairs. \
                            Otherwise, it is a single dataloader which returns data and a dictionary \
                            of name-label pairs in each iteration.

            test_dataloaders (dict or torch.utils.data.DataLoader): The dataloaders used for validation or test. \
                            The same structure with ``train_dataloaders``.
            epochs (int): The total training epochs.
            return_weight (bool): if ``True``, the loss weights will be returned.
        '''
        train_loader, train_batch = self._prepare_dataloaders(train_dataloaders)  # prepare the dataloaders
        train_batch = max(train_batch) if self.multi_input else train_batch
        
        self.batch_weight = np.zeros([self.task_num, epochs, train_batch])  # records how the loss weight of each task changes in every batch
        self.model.train_loss_buffer = np.zeros([self.task_num, epochs])  # records each task's loss in every epoch
        for epoch in range(epochs):
            self.model.epoch = epoch
            self.model.train()
            self.meter.record_time('begin')  # start timing this epoch
            for batch_index in range(train_batch):
                if not self.multi_input:  # single input: all tasks share the same input batch
                    train_inputs, train_gts = self._process_data(train_loader)
                    train_preds = self.model(train_inputs)
                    train_preds = self.process_preds(train_preds)
                    train_losses = self._compute_loss(train_preds, train_gts)
                    self.meter.update(train_preds, train_gts)
                else:  # multi-input: every task has its own dataloader
                    train_losses = torch.zeros(self.task_num).to(self.device)  # the loss of each task in this batch
                    for tn, task in enumerate(self.task_name):  # iterate over the tasks
                        train_input, train_gt = self._process_data(train_loader[task])  # get this task's batch; the task-specific settings come from the task_dict parsed earlier
                        train_pred = self.model(train_input, task)  # forward through the shared encoder and this task's decoder
                        train_pred = train_pred[task]  # take this task's prediction out of the returned dict
                        train_pred = self.process_preds(train_pred, task)  # optional post-processing; override process_preds() to customize
                        train_losses[tn] = self._compute_loss(train_pred, train_gt, task)  # compute the loss
                        self.meter.update(train_pred, train_gt, task)  # update the per-batch records via this task's Metric class

                self.optimizer.zero_grad()
                w = self.model.backward(train_losses, **self.kwargs['weight_args'])  # backward pass; the loss-weighting strategy is applied here
                if w is not None:
                    self.batch_weight[:, epoch, batch_index] = w
                self.optimizer.step()
            
            self.meter.record_time('end')  # one epoch ends
            self.meter.get_score()  # compute every task's metric scores
            self.model.train_loss_buffer[:, epoch] = self.meter.loss_item  # store every task's average loss for this epoch
            self.meter.display(epoch=epoch, mode='train')  # display every task's metric scores
            self.meter.reinit()  # reset the records in the Metric classes
            
            if val_dataloaders is not None:  # validation
                self.meter.has_val = True
                self.test(val_dataloaders, epoch, mode='val')
            self.test(test_dataloaders, epoch, mode='test')  # test
            if self.scheduler is not None:
                self.scheduler.step()
        self.meter.display_best_result()  # after all epochs, display the best result
        if return_weight:
            return self.batch_weight
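
As a usage note, train() optionally returns the recorded loss weights, so the officeModel from the first example could be trained like this (a sketch reusing the names defined above):

# batch_weight has shape [task_num, epochs, train_batch] and holds the loss weight of every
# task in every batch (it stays zero when the weighting method does not produce weights)
batch_weight = officeModel.train(train_dataloaders=train_dataloaders,
                                 val_dataloaders=val_dataloaders,
                                 test_dataloaders=test_dataloaders,
                                 epochs=params.epochs,
                                 return_weight=True)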
