一、问题描述
在多卡加载与训练模型的时候,出现显卡内存不足的错误,batchsize放到很小以后观察发现第一张卡占用内存比其他几张大了很多。
#迁移学习
if cfg.feature_extract:
#print("feature_extract")
backbone_weight_path=cfg.backbone_weight_path
load_model = torch.load(backbone_weight_path)
backbone.load_state_dict(load_model, strict=False)
解决方法:
把预训练模型参数map到cpu上去。
修改代码
#迁移学习
if cfg.feature_extract:
#print("feature_extract")
backbone_weight_path=cfg.backbone_weight_path
load_model = torch.load(backbone_weight_path,map_location='cpu')
backbone.load_state_dict(load_model, strict=False)