关于深度学习模型训练中相关参数的相关知识

一.模型参数的加载

深度学习模型的参数加载是按照模型架构的参数去到权重文件中寻找,直到完全匹配,两者之间的数目需要完全匹配,不然load_state_dict()函数会报错。具体见下例:

该模型名字为angle_model,里面包含features、avgpool以及angle模块。其中features模块与vgg16的features模块一致。如果我们想要将vgg16的预训练的参数加载到该features中,首先我们需要将vgg16中与angle_model相匹配的参数给提取出来。我们打开vgg16的权重文件可以看到其ordereddict的键名为:
所以我们需要去除掉不属于features的权重。
del_k = []
change_k = []
for k, v in vgg_weight.items():
	if "classifier" in k:
		del_k.append(k)
	if "features" in k:
		change_k.append(k)
	for k in del_k:
		del vgg_weight[k]
	for k in change_k:
		k_s = k.split('.')
		vgg_weight[k_s[1] + "." + k_s[2]] = vgg_weight.pop(k)

上面的代码解释为:for k, v in vgg_weight.items()获得vgg_weight的键对值,k为键,v为值。然后剔除包含classifier字符的键对。由于我们是对angle_model的features模块的权重进行加载,故不可以使用model.load_state_dict(vgg_weight),而应该使用model.features.load_state_dict(vgg_weight)。如果使用上面的加载方式会导致加载参数时两者之间的不匹配。如果使用下面一中,需要对vgg_weight的键进行去除features.的操作。这是因为我们是使用model.features所以键名变为为:

而vgg_weight的键名以features开头加上x.weight(x为数字)。所以我们仍需要将vgg_weight的键名的features.部分去除。注意由于字典是不能改变键名的,所以可以将值弹出,重新赋予新键,即vgg_weight[k_s[1] + "." + k_s[2]] = vgg_weight.pop(k)

二.模型参数的固定

假如我们需要固定features模块的参数,我们可以这样做:

for key, value in model.named_parameters():
	if 'features' in key:
		value.requires_grad = False

或者可以写作:

for key, value in model.features.named_parameters():
		value.requires_grad = False
# 或者
for value in model.features.parameters():
		value.requires_grad = False

备注:

模型:

查看模型参数:model.state_dict(), 类型ordereddict
返回模型参数迭代器:model.parameters() 迭代器,使用for或者next查看
返回模型名字和参数迭代器:model.named_parameters() 迭代器,使用for或者next查看

查看模型结构:model._module, 类型ordereddict
返回子模块迭代器:model.children() 迭代器,使用for或者next查看
返回子模块名字和子模块迭代器:model.named_children() 迭代器,使用for或者next查看

权重文件:

权重文件为ordereddict,相关操作和字典类似

ordereddict:

查看键:dict.keys()
返回键对值迭代器:dict.items()
删除键对值:del dict[key]
替换键对值:vgg_weight[new_key] = vgg_weight.pop(old_key)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值