一.模型参数的加载
深度学习模型的参数加载是按照模型架构的参数去到权重文件中寻找,直到完全匹配,两者之间的数目需要完全匹配,不然load_state_dict()函数会报错。具体见下例:
del_k = []
change_k = []
for k, v in vgg_weight.items():
if "classifier" in k:
del_k.append(k)
if "features" in k:
change_k.append(k)
for k in del_k:
del vgg_weight[k]
for k in change_k:
k_s = k.split('.')
vgg_weight[k_s[1] + "." + k_s[2]] = vgg_weight.pop(k)
上面的代码解释为:for k, v in vgg_weight.items()获得vgg_weight的键对值,k为键,v为值。然后剔除包含classifier字符的键对。由于我们是对angle_model的features模块的权重进行加载,故不可以使用model.load_state_dict(vgg_weight),而应该使用model.features.load_state_dict(vgg_weight)。如果使用上面的加载方式会导致加载参数时两者之间的不匹配。如果使用下面一中,需要对vgg_weight的键进行去除features.的操作。这是因为我们是使用model.features所以键名变为为:
而vgg_weight的键名以features开头加上x.weight(x为数字)。所以我们仍需要将vgg_weight的键名的features.部分去除。注意由于字典是不能改变键名的,所以可以将值弹出,重新赋予新键,即vgg_weight[k_s[1] + "." + k_s[2]] = vgg_weight.pop(k)
。
二.模型参数的固定
假如我们需要固定features模块的参数,我们可以这样做:
for key, value in model.named_parameters():
if 'features' in key:
value.requires_grad = False
或者可以写作:
for key, value in model.features.named_parameters():
value.requires_grad = False
# 或者
for value in model.features.parameters():
value.requires_grad = False
备注:
模型:
查看模型参数:model.state_dict(), 类型ordereddict
返回模型参数迭代器:model.parameters() 迭代器,使用for或者next查看
返回模型名字和参数迭代器:model.named_parameters() 迭代器,使用for或者next查看
查看模型结构:model._module, 类型ordereddict
返回子模块迭代器:model.children() 迭代器,使用for或者next查看
返回子模块名字和子模块迭代器:model.named_children() 迭代器,使用for或者next查看
权重文件:
权重文件为ordereddict,相关操作和字典类似
ordereddict:
查看键:dict.keys()
返回键对值迭代器:dict.items()
删除键对值:del dict[key]
替换键对值:vgg_weight[new_key] = vgg_weight.pop(old_key)