Pytorch加载部分权重
在图像算法领域,我们有时会遇到训练了很长时间的网络,可能需要对原生网络做一些修改,修改后的网络再加载保存的网络权重时会报错。网络权重一般存储在字典中
1.修改网络层输出
比如在人脸检测项目中,已经训练好人脸框的回归,但是此时需要再加入人脸关键点。为了节约大量时间,我们可以加载部分权重。
加载的网络权重
if os.path.exists(self.load_params):
pretext_model = torch.load(self.load_params)
打印出来,会看到网络权重存储在一个字典中,需要修改哪一层,用字典的键索引值进行修改。比如原本输出层为4,我将网络输出层修改为14,又由于输出的都是坐标值,属于同一分布,所以我将原参4复制扩充为了14,效果非常好。
w = pretext_model["fc2.weight"]
b = pretext_model["fc2.bias"]
pretext_model["fc2.weight"