LibMTL 是一个基于 PyTorch 构建的用于多任务学习的开源库。
官方教程例子具体说明
使用的是 office 的例子:这里每个域(例如 office-31 中的 amazon、dslr、webcam)对应一个任务,所有任务都是同一类分类任务;例子中使用了自定义的共享网络,方便参照它编写自己的网络
# Register the command-line options you want to be able to override.
def parse_args(parser):
    """Add the example-specific options to ``parser`` and parse the command line.

    Args:
        parser: An argparse-style parser exposing ``add_argument`` and
            ``parse_args``.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    # (flag, default, type, help) for every option this example adds.
    option_specs = (
        ('--dataset', 'office-31', str, 'office-31, office-home'),
        ('--bs', 64, int, 'batch size'),
        ('--epochs', 100, int, 'training epochs'),
        ('--dataset_path', '/', str, 'dataset path'),
    )
    for flag, default_value, value_type, help_text in option_specs:
        parser.add_argument(flag, default=default_value, type=value_type,
                            help=help_text)
    return parser.parse_args()
def main(params):
    """Build the Office multi-task model and run training or testing.

    Args:
        params: Parsed command-line namespace. This function reads
            ``dataset``, ``bs``, ``dataset_path``, ``epochs``, ``weighting``,
            ``arch``, ``rep_grad``, ``multi_input``, ``save_path``,
            ``load_path`` and ``mode`` from it.

    Raises:
        ValueError: If ``params.dataset`` is not ``'office-31'`` or
            ``'office-home'``, or if ``params.mode`` is neither ``'train'``
            nor ``'test'``.
    """
    # The settings below can be overridden on ``params`` before calling
    # ``prepare_args``. Options listed by the original author:
    #   params.arch      : 'HPS', 'Cross_stitch', 'MTAN', 'CGC', 'PLE',
    #                      'MMoE', 'DSelect_k', 'DIY'
    #   params.weighting : 'EW', 'UW', 'GradNorm', 'GLS', 'RLW', 'MGDA',
    #                      'IMTL', 'PCGrad', 'GradVac', 'CAGrad', 'GradDrop',
    #                      'DWA', 'DIY'
    #   params.optim     : 'adam', 'sgd', 'adagrad', 'rmsprop'
    #   params.scheduler : 'step', 'cos', 'exp' (or None)
    # ``prepare_args`` yields the architecture/weighting kwargs, the optimizer
    # settings and the learning-rate scheduler settings.
    kwargs, optim_param, scheduler_param = prepare_args(params)

    # Each domain of the chosen dataset is treated as one task.
    if params.dataset == 'office-31':
        task_name = ['amazon', 'dslr', 'webcam']
        class_num = 31
    elif params.dataset == 'office-home':
        task_name = ['Art', 'Clipart', 'Product', 'Real_World']
        class_num = 65
    else:
        raise ValueError('No support dataset {}'.format(params.dataset))

    # Define the tasks: every task uses accuracy as its metric and
    # cross-entropy as its loss.
    task_dict = {task: {'metrics': ['Acc'],            # metric names
                        'metrics_fn': AccMetric(),     # object computing the metrics
                        'loss_fn': CELoss(),           # loss function
                        # One weight per metric; per the original note it is
                        # used later when saving the best result.
                        # NOTE(review): confirm exact semantics in LibMTL docs.
                        'weight': [1]}
                 for task in task_name}

    # Prepare one train/val/test dataloader per task.
    data_loader, _ = office_dataloader(dataset=params.dataset,
                                       batchsize=params.bs,
                                       root_path=params.dataset_path)
    train_dataloaders = {task: data_loader[task]['train'] for task in task_name}
    val_dataloaders = {task: data_loader[task]['val'] for task in task_name}
    test_dataloaders = {task: data_loader[task]['test'] for task in task_name}

    # Custom shared encoder: an ordinary torch module.
    class Encoder(nn.Module):
        """Shared feature extractor: backbone, pooling, then a hidden block."""

        def __init__(self):
            super().__init__()
            hidden_dim = 512
            self.resnet_network = resnet18(pretrained=True)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            # Linear(512, ...) assumes the backbone emits 512 channels
            # -- TODO confirm against the resnet18 in use.
            self.hidden_layer_list = [nn.Linear(512, hidden_dim),
                                      nn.BatchNorm1d(hidden_dim),
                                      nn.ReLU(),
                                      nn.Dropout(0.5)]
            self.hidden_layer = nn.Sequential(*self.hidden_layer_list)

            # Initialise the first linear layer explicitly.
            self.hidden_layer[0].weight.data.normal_(0, 0.005)
            self.hidden_layer[0].bias.data.fill_(0.1)

        def forward(self, inputs):
            """Return the shared representation for ``inputs``."""
            out = self.resnet_network(inputs)
            out = torch.flatten(self.avgpool(out), 1)
            out = self.hidden_layer(out)
            return out

    # Task-specific "tower": one linear classifier per task.
    decoders = nn.ModuleDict({task: nn.Linear(512, class_num)
                              for task in task_dict})

    # The multi-task model uses LibMTL's Trainer directly; subclass or modify
    # Trainer for custom behaviour.
    officeModel = Trainer(task_dict=task_dict,
                          weighting=weighting_method.__dict__[params.weighting],
                          architecture=architecture_method.__dict__[params.arch],
                          encoder_class=Encoder,
                          decoders=decoders,
                          rep_grad=params.rep_grad,
                          multi_input=params.multi_input,
                          optim_param=optim_param,
                          scheduler_param=scheduler_param,
                          save_path=params.save_path,
                          load_path=params.load_path,
                          **kwargs)

    if params.mode == 'train':
        officeModel.train(train_dataloaders=train_dataloaders,
                          val_dataloaders=val_dataloaders,
                          test_dataloaders=test_dataloaders,
                          epochs=params.epochs)
    elif params.mode == 'test':
        officeModel.test(test_dataloaders)
    else:
        # Report the offending value instead of raising a bare ValueError.
        raise ValueError('No support mode {}'.format(params.mode))
if __name__ == "__main__":
    # Script entry point: extend the parser object ``LibMTL_args`` (defined
    # outside this snippet) with the example's options and parse the
    # command line.
    params = parse_args(LibMTL_args)
    # Select the device from the parsed ``gpu_id``.
    set_device(params.gpu_id)
    # Seed the random number generators from the parsed ``seed``.
    set_random_seed(params.seed)
    main(params)
task_dict具体解析
这里使用nyu里面的例子,因为这里有三个任务
# Task definitions with three tasks. Each entry lists the metric names, the
# object that computes them, the loss function, and one weight per metric.
# NOTE(review): in LibMTL the weight appears to flag higher-is-better (1) vs
# lower-is-better (0) for each metric -- confirm against the LibMTL docs.
task_dict = {
    # Task 1: semantic segmentation.
    'segmentation': {
        'metrics': ['mIoU', 'pixAcc'],
        'metrics_fn': SegMetric(),
        'loss_fn': SegLoss(),
        'weight': [1, 1],
    },
    # Task 2: depth estimation.
    'depth': {
        'metrics': ['abs_err', 'rel_err'],
        'metrics_fn': DepthMetric(),
        'loss_fn': DepthLoss(),
        'weight': [0, 0],
    },
    # Task 3: surface-normal prediction.
    'normal': {
        'metrics': ['mean', 'median', '<11.25', '<22.5', '<30'],
        'metrics_fn': NormalMetric(),
        'loss_fn': NormalLoss(),
        'weight': [0, 0, 1, 1, 1],
    },
}
AbsMetric()
这里用于计算评价指标的类都继承自 AbsMetric(),我们可以基于这个类按自己的需求进行改写;这里提供的几个基础方法都是必须实现的
class AbsMetric(object):
    r"""An abstract class for the performance metrics of a task.

    Subclasses override :meth:`update_fun` and :meth:`score_fun` as ordinary
    methods. They are deliberately NOT properties: a property getter cannot
    receive ``pred``/``gt``, and a property would make ``metric.score_fun()``
    call the getter's return value instead of the method.

    Attributes:
        record (list): A list of the metric scores in every iteration.
        bs (list): A list of the number of data in every iteration.
    """

    def __init__(self):
        # Per-iteration accumulators, cleared again by ``reinit``.
        self.record = []
        self.bs = []

    def update_fun(self, pred, gt):
        r"""Calculate the metric scores in every iteration and update :attr:`record`.

        Called once per batch; record whatever per-batch data is needed.

        Args:
            pred (torch.Tensor): The prediction tensor.
            gt (torch.Tensor): The ground-truth tensor.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError('update_fun must be implemented by subclasses')

    def score_fun(self):
        r"""Calculate the final score (when a epoch ends).

        Return:
            list: A list of metric scores.

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError('score_fun must be implemented by subclasses')

    def reinit(self):
        r"""Reset :attr:`record` and :attr:`bs` (when a epoch ends).
        """
        self.record = []
        self.bs = []
Metric() 改写举例
class AccMetric(AbsMetric):
    r"""Calculate the accuracy.

    Accumulates the number of correct predictions and the number of samples
    per batch, then reports their ratio.
    """

    def __init__(self):
        super().__init__()

    def update_fun(self, pred, gt):
        r"""Record the correct-prediction count and the size of one batch.

        Args:
            pred (torch.Tensor): Scores; softmax and the max are taken over
                the last dimension.
            gt (torch.Tensor): Ground-truth labels compared element-wise with
                the predicted indices.
        """
        probabilities = F.softmax(pred, dim=-1)
        _, predicted = probabilities.max(-1)
        num_correct = gt.eq(predicted).sum().item()
        self.record.append(num_correct)
        # Batch size = first dimension of the predicted-index tensor.
        self.bs.append(predicted.size()[0])

    def score_fun(self):
        r"""Return the accuracy over everything recorded since the last reset.

        Return:
            list: A one-element list holding ``sum(record) / sum(bs)``.
        """
        total_correct = sum(self.record)
        total_seen = sum(self.bs)
        return [(total_correct / total_seen)]
Trainer
process_preds
库里默认不做任何处理(直接返回 preds);如果需要对预测结果做进一步处理,可以重写这个函数
def process_preds(self, preds, task_name=None):
    r'''Post-process the predictions of each task (identity by default).

    Redefine this hook when predictions need extra processing before the
    loss and metrics are computed.

    - If ``multi_input``, ``task_name`` is valid and ``preds`` is the
      :class:`torch.Tensor` prediction of that single task.
    - Otherwise ``task_name`` is invalid and ``preds`` is a :class:`dict`
      of name-prediction pairs of all tasks.

    Args:
        preds (dict or torch.Tensor): The prediction of ``task_name`` or all tasks.
        task_name (str): The string of task name.

    Returns:
        The ``preds`` argument, unchanged.
    '''
    return preds
train
这个就是训练函数,具体的训练细节可以在这里更改
def train(self, train_dataloaders, test_dataloaders, epochs,
          val_dataloaders=None, return_weight=False):
    r'''Run the multi-task training loop.

    Args:
        train_dataloaders (dict or torch.utils.data.DataLoader): The dataloaders used for training. \
                        If ``multi_input`` is ``True``, it is a dictionary of name-dataloader pairs. \
                        Otherwise, it is a single dataloader which returns data and a dictionary \
                        of name-label pairs in each iteration.
        test_dataloaders (dict or torch.utils.data.DataLoader): The dataloaders used for test. \
                        The same structure with ``train_dataloaders``.
        epochs (int): The total training epochs.
        val_dataloaders (dict or torch.utils.data.DataLoader, optional): When not ``None``, \
                        a validation pass is run after every epoch, before the test pass.
        return_weight (bool): if ``True``, the loss weights will be returned.

    Returns:
        ``self.batch_weight`` (array of shape ``[task_num, epochs, train_batch]``)
        when ``return_weight`` is true, otherwise ``None``.
    '''
    train_loader, train_batch = self._prepare_dataloaders(train_dataloaders)
    if self.multi_input:
        # One batch count per task: iterate as many times as the largest one.
        train_batch = max(train_batch)

    # Loss weight of every task for every batch of every epoch.
    self.batch_weight = np.zeros([self.task_num, epochs, train_batch])
    # Per-task training loss for every epoch.
    self.model.train_loss_buffer = np.zeros([self.task_num, epochs])

    for epoch in range(epochs):
        self.model.epoch = epoch
        self.model.train()
        self.meter.record_time('begin')  # start timing this epoch

        for batch_index in range(train_batch):
            if self.multi_input:
                # Each task has its own input: run the tasks one by one and
                # collect one loss per task.
                train_losses = torch.zeros(self.task_num).to(self.device)
                for tn, task in enumerate(self.task_name):
                    task_input, task_gt = self._process_data(train_loader[task])
                    # Forward pass, then pick this task's entry from the output.
                    task_pred = self.model(task_input, task)[task]
                    # Optional user-defined post-processing hook.
                    task_pred = self.process_preds(task_pred, task)
                    train_losses[tn] = self._compute_loss(task_pred, task_gt, task)
                    # Feed the batch to the meter for this task.
                    self.meter.update(task_pred, task_gt, task)
            else:
                # A single input shared by all tasks.
                batch_inputs, batch_gts = self._process_data(train_loader)
                batch_preds = self.process_preds(self.model(batch_inputs))
                train_losses = self._compute_loss(batch_preds, batch_gts)
                self.meter.update(batch_preds, batch_gts)

            self.optimizer.zero_grad()
            # Backward pass; it may return the loss weights used for this batch.
            w = self.model.backward(train_losses, **self.kwargs['weight_args'])
            if w is not None:
                self.batch_weight[:, epoch, batch_index] = w
            self.optimizer.step()

        self.meter.record_time('end')  # stop timing this epoch
        self.meter.get_score()  # compute the metrics of every task
        self.model.train_loss_buffer[:, epoch] = self.meter.loss_item
        self.meter.display(epoch=epoch, mode='train')
        self.meter.reinit()  # reset the meter state for the next pass

        if val_dataloaders is not None:
            self.meter.has_val = True
            self.test(val_dataloaders, epoch, mode='val')
        self.test(test_dataloaders, epoch, mode='test')

        if self.scheduler is not None:
            self.scheduler.step()

    # Training finished: show the best result.
    self.meter.display_best_result()
    return self.batch_weight if return_weight else None