PyTorch: how to define which layers are excluded from gradient updates

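The class below references a ResBlock that the post never defines. Here is a minimal hypothetical stand-in (two 3x3 convolutions with a residual connection, so in_ch must equal out_ch), just to make the example self-contained; it is not the author's actual block:

import torch
import torch.nn as nn

class ResBlock(nn.Module):
    # Hypothetical placeholder, assumed only so the example runs.
    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, mid_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, 3, padding=1)

    def forward(self, x):
        return self.conv2(torch.relu(self.conv1(x))) + x
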
import torch
import torch.nn as nn

class OurNet(nn.Module):
    def __init__(self):
        super(OurNet, self).__init__()
        # conv2a/conv3a take 64 input channels because conv1a outputs 64.
        self.conv1a = nn.Conv2d(3, 64, 3, padding=1, bias=False)
        self.conv2a = nn.Conv2d(64, 64, 3, padding=1, bias=False)
        self.conv3a = nn.Conv2d(64, 64, 3, padding=1, bias=False)
        self.b3_2 = ResBlock(256, 256, 256)

        self.not_training = [self.conv1a, self.b3_2]
        # Layers meant to be trained from scratch; get_parameter_groups() below
        # gives them a higher learning rate. Fill in as needed.
        self.from_scratch_layers = []

    def forward(self, x):
        x = self.conv1a(x)
        x = self.conv2a(x)
        x = self.conv3a(x)
        return x

# For each parameter you can: 1. set a custom initialization, 2. toggle
# requires_grad (a bool, True or False), 3. assign its own learning rate.

    # Here we define which parameters are NOT updated.
    def train(self, mode=True):
        super().train(mode)  # call PyTorch's own train() first; this is required
        for layer in self.not_training:
            # 1st iteration: layer = self.conv1a, i.e. nn.Conv2d(3, 64, 3, padding=1, bias=False);
            # 2nd iteration: layer = self.b3_2, i.e. ResBlock(256, 256, 256).
            if isinstance(layer, torch.nn.Conv2d):
                layer.weight.requires_grad = False

            elif isinstance(layer, torch.nn.Module):
                # An instance of a custom class such as ResBlock(nn.Module) is
                # itself a torch.nn.Module. Freeze its direct children; for
                # deeply nested blocks, iterating layer.parameters() is safer.
                for c in layer.children():
                    if getattr(c, 'weight', None) is not None:
                        c.weight.requires_grad = False
                    if getattr(c, 'bias', None) is not None:
                        c.bias.requires_grad = False
        # Once we can get at the modules, we can both freeze layers and apply
        # custom initialization.
        for layer in self.modules():  # freeze all BatchNorm layers
            if isinstance(layer, torch.nn.BatchNorm2d):
                layer.eval()  # also keeps the running statistics fixed
                layer.bias.requires_grad = False
                layer.weight.requires_grad = False

        for layer in self.modules():  # custom initialization
            if isinstance(layer, torch.nn.Conv2d):
                torch.nn.init.xavier_uniform_(layer.weight)
                # kaiming_normal_ cannot initialize a 1-D bias tensor (and these
                # convs were built with bias=False); zero the bias if it exists.
                if layer.bias is not None:
                    torch.nn.init.zeros_(layer.bias)

        return


# Assign per-parameter learning rates by sorting parameters into four groups:
# (pretrained weights, pretrained biases, from-scratch weights, from-scratch biases).
    def get_parameter_groups(self):
        groups = ([], [], [], [])
        print('======================================================')
        for m in self.modules():

            if (isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.modules.normalization.GroupNorm)):

                if m.weight.requires_grad:
                    if m in self.from_scratch_layers:
                        groups[2].append(m.weight)
                    else:
                        groups[0].append(m.weight)

                if m.bias is not None and m.bias.requires_grad:
                    if m in self.from_scratch_layers:
                        groups[3].append(m.bias)
                    else:
                        groups[1].append(m.bias)

        return groups
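
Once the overridden train() has run, frozen parameters simply report requires_grad=False. A quick sanity check (using the hypothetical ResBlock sketched above):

model = OurNet()
model.train()  # triggers the override, freezing conv1a and b3_2

for name, p in model.named_parameters():
    print(name, p.requires_grad)  # conv1a.* and b3_2.* should print False

# get_parameter_groups() already skips frozen parameters (it checks
# requires_grad); without groups, filter manually:
trainable = [p for p in model.parameters() if p.requires_grad]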





model.train()  # runs the overridden train(), re-freezing the chosen layers
param_groups = model.get_parameter_groups()
optimizer = torchutils.PolyOptimizer([
    {'params': param_groups[0], 'lr': args.lr, 'weight_decay': args.wt_dec},     # pretrained weights
    {'params': param_groups[1], 'lr': 2*args.lr, 'weight_decay': 0},             # pretrained biases
    {'params': param_groups[2], 'lr': 10*args.lr, 'weight_decay': args.wt_dec},  # from-scratch weights
    {'params': param_groups[3], 'lr': 20*args.lr, 'weight_decay': 0}             # from-scratch biases
    ], lr=args.lr, weight_decay=args.wt_dec, max_step=max_step)
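
torchutils.PolyOptimizer is not part of core PyTorch; it comes from the training utilities of the codebase this snippet is adapted from. A minimal sketch of the idea, assuming the usual polynomial learning-rate decay (power 0.9) layered on top of SGD:

import torch

class PolyOptimizer(torch.optim.SGD):
    # Sketch: each group's lr decays as initial_lr * (1 - step/max_step) ** 0.9.
    def __init__(self, params, lr, weight_decay, max_step, momentum=0.9):
        super().__init__(params, lr, weight_decay=weight_decay, momentum=momentum)
        self.global_step = 0
        self.max_step = max_step
        # remember each group's starting lr so decay applies to the original value
        self._initial_lr = [g['lr'] for g in self.param_groups]

    def step(self, closure=None):
        if self.global_step < self.max_step:
            lr_mult = (1 - self.global_step / self.max_step) ** 0.9
            for g, lr0 in zip(self.param_groups, self._initial_lr):
                g['lr'] = lr0 * lr_mult
        super().step(closure)
        self.global_step += 1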

 
