最近博主在尝试多任务网络,简而言之就是网络中有一个backbone和多个head,每个head对应不同的任务。训练多任务网络,有一种训练方法是固定住backbone,每个head单独训练,这样的话head之间互不影响 ,是有利于提高单个任务的精度的。
但是博主在写代码时却发现,虽然已经小心翼翼了,但是head之间依然相互影响了,于是开始了漫长的debug过程。。。。。
以下是博主训练时的伪代码~
import pyorch
class MTL(nn.Module):
self.backbone = resnet18()
self.head1 = Linear()
self.head2 = Linear()
def forward(x, task):
if task = "1":
return self.head1(self.backbone(x))
elif task = "2":
return self.head2(self.backbone(x))
model = MTL()
model.load_weight()
## check pretained accuracy
task_1_pred = []
task_2_pred = []
model.eval()
for data in test_dataloader:
x