[pytorch]训练的时候固定部分层的参数

[pytorch]训练的时候固定部分层的参数

有些时候我们在写自己的网络的时候需要用到其他人或者pytorch的torchvision.models里预训练好的模型,但是我们可能希望固定一部分层的参数,在训练的时候不更新这些层的参数,这意味着我们希望反向传播计算梯度时,只计算剩余层的参数的梯度。
我们知道,网络中的所有操作对象都是Variable对象,在这篇文章中笔者将介绍如何利用Variable的requires_grad来实现我们的目的。

首先我们定义一个简单的网络,这个网络中使用torchvision.models的resnet50网络。

class model_resnet50_2(nn.Module):
    '''
    it can extract layer_features
    '''
    def __init__(self,num_classes):
        super(model_resnet50_2,self).__init__()

        resnet = models.resnet50(pretrained=True)
        modules_1 = list(resnet.children())[:-6]     # delete the last fc layer. 
        modules_2 = list(resnet.children())[-6:-4]
        modules_3 = list(resnet.children())[-4:-2]
        modules_4 = list(resnet.children())[-2:-1]
   
        self.convnet_1 = nn.Sequential(*modules_1)
        self.convnet_2 = nn.Sequential(*modules_2)
        self.convnet_3 = nn.Sequential(*modules_3)
        self.convnet_4 = nn.Sequential(*modules_4)
        self.fc = nn.Linear(2048,num_classes)

    def forward(self,x):
        layer_features=[]      
        feature = self.convnet_1(x)
        layer_features.append(feature)
        feature = self.convnet_2(feature)
        layer_features.append(feature)
        feature = self.convnet_3(feature)
        layer_features.append(feature)
        feature = self.convnet_4(feature)
        feature = feature.view(x.size(0), -1)
        output = self.fc(feature)
        return feature,output,layer_features

resnet_backbone = model_resnet50_2(64)
print(list(resnet_backbone.children()))

我们可以通过print来打印出这个网络的内部结构,由于输出太长这里就不贴上去了。

看这篇文章的小伙伴肯定都知道,pytorch的网络参数更新,是通过将普通的tensor变量转化为Variable变量实现反向梯度的计算和传播的,因此,只要我们让特定层的参数Variable变量requires_grad设置为false,自然而然在反向传播的时候就不会更新这些层已经学习过的参数了。

for k,v in resnet_backbone.named_parameters():
    if 'bn' in k:
        v.requires_grad = False


for k,v in resnet_backbone.named_parameters():
    print('{}: {}'.format(k, v.requires_grad))

以下是打印的结果:

convnet_1.0.weight: True
convnet_1.1.weight: True
convnet_1.1.bias: True
convnet_2.0.0.conv1.weight: True
convnet_2.0.0.bn1.weight: False
convnet_2.0.0.bn1.bias: False
convnet_2.0.0.conv2.weight: True
convnet_2.0.0.bn2.weight: False
convnet_2.0.0.bn2.bias: False
convnet_2.0.0.conv3.weight: True
convnet_2.0.0.bn3.weight: False
convnet_2.0.0.bn3.bias: False
convnet_2.0.0.downsample.0.weight: True
convnet_2.0.0.downsample.1.weight: True
convnet_2.0.0.downsample.1.bias: True

已标记关键词 清除标记
©️2020 CSDN 皮肤主题: 数字20 设计师:CSDN官方博客 返回首页