具体报错:
报错原因:
创建的网络模型与导入的预训练参数在最后一层的全连接层是不匹配的,创建的模型为六分类任务,然后导入的参数是21841分类任务。
解决办法:
将导入模型的最后一层参数去掉,可以通提供print输出参数的的内容。来查看最后一层的名称。我的最后一层名为:head
#将保存的模型参数转换成字典的形式
checkpoint = torch.load(args.resume, map_location='cpu')
#去除最后一层全连接层的关键,实际是去掉字典的最后一组键值对
checkpoint = {k: v for k, v in checkpoint['model'].items() if 'head' not in k}
#加载参数至创建的模型
model_without_ddp.load_state_dict(checkpoint,strict=False)