在torch模型中有一个model.parameters()函数,可以是一个生成器,可以迭代返回模型的所有参数,利用这个函数可以迭代修改里面的参数从而达到weight不更新的目标。
代码(以resnet50为例):
model = models.resnet50(pretrained=True)
#目标冻结层数
target_frozennum = 40
cnt = 0
for param in model.parameters():
cnt += 1
if cnt == target_frozennum:
break
param.requires_grad = False