[PyTorch] Freezing the parameters of some layers during training
Sometimes, when building our own network, we want to reuse a model pretrained by someone else or one of the pretrained models in torchvision.models, but we would like to freeze the parameters of some of its layers so that they are not updated during training. In other words, during backpropagation, gradients should only be computed for the parameters of the remaining layers.
As we know, everything a network operates on is a tensor that participates in autograd (in older PyTorch versions these were Variable objects; since PyTorch 0.4, Variable has been merged into Tensor). In this article I will show how to use the requires_grad attribute to achieve our goal.
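To recall the mechanism, here is a minimal sketch (the tensor names w and b are made up for illustration): a tensor with requires_grad=False simply receives no gradient during backward().

import torch

w = torch.randn(3, requires_grad=True)   # trainable: autograd tracks it
b = torch.randn(3, requires_grad=False)  # frozen: excluded from autograd
loss = (w + b).sum()
loss.backward()
print(w.grad)  # tensor([1., 1., 1.])
print(b.grad)  # None -- no gradient was computed for b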
First, let's define a simple network that uses the resnet50 model from torchvision.models.
import torch.nn as nn
from torchvision import models

class model_resnet50_2(nn.Module):
    '''
    ResNet-50 backbone split into stages so that
    intermediate layer features can be extracted.
    '''
    def __init__(self, num_classes):
        super(model_resnet50_2, self).__init__()
        resnet = models.resnet50(pretrained=True)
        # split resnet50's children into four stages; the final fc
        # layer is dropped and replaced by self.fc below
        modules_1 = list(resnet.children())[:-6]    # conv1, bn1, relu, maxpool
        modules_2 = list(resnet.children())[-6:-4]  # layer1, layer2
        modules_3 = list(resnet.children())[-4:-2]  # layer3, layer4
        modules_4 = list(resnet.children())[-2:-1]  # avgpool
        self.convnet_1 = nn.Sequential(*modules_1)
        self.convnet_2 = nn.Sequential(*modules_2)
        self.convnet_3 = nn.Sequential(*modules_3)
        self.convnet_4 = nn.Sequential(*modules_4)
        self.fc = nn.Linear(2048, num_classes)

    def forward(self, x):
        layer_features = []
        feature = self.convnet_1(x)
        layer_features.append(feature)
        feature = self.convnet_2(feature)
        layer_features.append(feature)
        feature = self.convnet_3(feature)
        layer_features.append(feature)
        feature = self.convnet_4(feature)
        feature = feature.view(x.size(0), -1)
        output = self.fc(feature)
        return feature, output, layer_features

resnet_backbone = model_resnet50_2(64)
print(list(resnet_backbone.children()))
We can print out the internal structure of this network with print; the output is too long to include here.
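As a quick sanity check, the following sketch runs a dummy batch through the network (the batch size 2 and input size 224x224 are my own choices for illustration):

import torch

x = torch.randn(2, 3, 224, 224)  # hypothetical dummy batch
feature, output, layer_features = resnet_backbone(x)
print(feature.shape)  # torch.Size([2, 2048]), pooled backbone feature
print(output.shape)   # torch.Size([2, 64]), class scores
print([f.shape for f in layer_features])
# [torch.Size([2, 64, 56, 56]), torch.Size([2, 512, 28, 28]),
#  torch.Size([2, 2048, 7, 7])]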
As readers of this article surely know, PyTorch computes and propagates gradients through autograd for every parameter whose requires_grad flag is True (in old versions this meant wrapping plain tensors in Variable). Therefore, as long as we set requires_grad of the parameters of specific layers to False, no gradients will be computed for them during backpropagation, and the weights those layers have already learned will naturally not be updated.
for k, v in resnet_backbone.named_parameters():
    if 'bn' in k:
        v.requires_grad = False

for k, v in resnet_backbone.named_parameters():
    print('{}: {}'.format(k, v.requires_grad))
Here is part of the printed result (truncated):
convnet_1.0.weight: True
convnet_1.1.weight: True
convnet_1.1.bias: True
convnet_2.0.0.conv1.weight: True
convnet_2.0.0.bn1.weight: False
convnet_2.0.0.bn1.bias: False
convnet_2.0.0.conv2.weight: True
convnet_2.0.0.bn2.weight: False
convnet_2.0.0.bn2.bias: False
convnet_2.0.0.conv3.weight: True
convnet_2.0.0.bn3.weight: False
convnet_2.0.0.bn3.bias: False
convnet_2.0.0.downsample.0.weight: True
convnet_2.0.0.downsample.1.weight: True
convnet_2.0.0.downsample.1.bias: True
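Two caveats are visible in the output above. First, convnet_2.0.0.downsample.1 is also a BatchNorm layer, yet it stays trainable (True), because its parameter name does not contain the substring 'bn', so the filter misses it. Second, requires_grad=False only freezes BatchNorm's affine weight and bias; its running mean and variance are still updated in train() mode and have to be frozen separately via .eval(). When building the optimizer, it is also common practice to pass only the parameters that still require gradients. A minimal sketch (the learning rate and momentum are placeholders):

import torch.nn as nn
import torch.optim as optim

# only parameters that still require gradients go to the optimizer
trainable_params = filter(lambda p: p.requires_grad,
                          resnet_backbone.parameters())
optimizer = optim.SGD(trainable_params, lr=1e-3, momentum=0.9)

# optionally freeze BN running statistics as well
# (note: a later call to resnet_backbone.train() switches BN back
#  to training mode, so re-apply this after such calls)
for m in resnet_backbone.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eval()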