PyTorch项目应用实例(八)固定权重|顺序训练网络

背景:需要将模型分层训练,不能同时训练。即固定一部分权重训练另一部分。

目录

一、多输出网络

1.1 heads多输出

1.2 最终的输出

二、loss更新网络的方法

2.1 loss位置

2.2 criterion定义

2.3 optimizer定义

2.4 定义顺序汇总

三、更改相应代码

3.1 定义需要优化的参数

3.2 更改学习率

3.3 网络loss及结构


一、多输出网络

参见 PyTorch应用实例(七)模型添加中继loss | 中继监督优化

在此基础上还需要对网络结构进行更改,features提取特征,head进一步提取,GAT_and_final_fcs进行最终输出。

1.1 heads多输出

其他不变,到group_linear之后就把相应的fcs输出

    # fixme  head forward
    def forward(self, x):
        # input x [batch_size, 2048, W=14, H=14]
        # conv from 2048 to Group_channel=512
        x=self.reduce_conv(x)
        # output x [B, group_channels=512, W=14, H=14]

        # output x [B, group_channels*groups=512*12=6144, W=14, H=14]
        x=self.bottle_nect(x)

        # output x [ Batch, n_groups=12, group_channels=512 ]
        x = self.gmp(x).view(x.size(0), self.groups,self.group_channels)

        # group linear from group to classes
        # [ Batch, n_groups=12, group_channels=512 ] ->  [ batch, n_classes, class_channels ]
        x=self.group_linear(x)

        # supplement structure for supplement loss
        # input size same as input of gat :[ batch, n_classes, class_channels ]
        # output size same as model output : [ batch , n_classes ]
        supplement_out=self.supplement_output_structure(x).view(x.size(0),x.size(1))

        # # GAT between classes
        # # input,output: [ batch, n_classes, class_channels ]
        # x = self.gat(x)
        #
        # # output  [ batch , n_classes ]
        # x= self.final_fcs(x).view(x.size(0),x.size(1))

        return x,supplement_out

1.2 最终的输出

在之前的基础上,创建新的class结构,网络为GAT与最终的输出

class Gat_and_final_fcs(nn.Module):
    def __init__(self, nclasses, class_channels):
        super(Gat_and_final_fcs, self).__init__()

        self.gat = BGATLayer(in_features=class_channels, out_features=class_channels, dropout=0, alpha=0.2)

        self.final_fcs=nn.Sequential(
                parallel_final_fcs_ResidualLinearBlock(n_classes=nclasses, class_channels=class_channels),
                parallel_output_linear(n_classes=nclasses,  class_channels=class_channels) )

    def forward(self,x):

        # GAT between classes
        # input,output: [ batch, n_classes, class_channels ]
        x = self.gat(x)

        # output  [ batch , n_classes ]
        x= self.final_fcs(x).view(x.size(0),x.size(1))

        return x

在最终的网络结构中加入下面内容:

        self.gat_and_final_fcs=Gat_and_final_fcs(nclasses=self.nclasses, class_channels=self.class_channels)

    def forward(self, x, inp):
        x = self.features(x)  # [B,2048,H,W] [2,2048,14,14]
        x,supplement_out = self.heads(x)
        x=self.gat_and_final_fcs(x)
        return x,supplement_out

二、loss更新网络的方法

2.1 loss位置

    criterion = util.get_criterion(args)
    model = util.get_model(args)
    # define optimizer
    optimizer = torch.optim.SGD(model.get_config_optim(args.LR, args.LRP),
                                momentum=args.MOMENTUM,
                                weight_decay=args.WEIGHT_DECAY)

通过 get_criterion定义loss

optimizer为随机提督下降算法,送入train

    trainer = T.Trainer(args, train_dataloader, val_dataloader, optimizer, model, criterion, lr_scheduler)
    trainer.run()

其中包括了 optimizer,criterion

2.2 criterion定义

即相应的loss定义:

def get_criterion(args):
    if args.LOSS_TYPE == 'MultiLabelSoftMarginLoss':
        criterion = nn.MultiLabelSoftMarginLoss()
    elif args.LOSS_TYPE == 'BCEWithLogitsLoss':
        criterion = nn.BCEWithLogitsLoss()
    elif args.LOSS_TYPE == 'DeepMarLoss':
        criterion = F.binary_cross_entropy_with_logits
    else:
        raise Exception()
    return criterion

我们常用的deepMarLoss,

区别是:softmax_cross_entropy_with_logits 要求传入的 labels 是经过 one_hot encoding 的数据,而 sparse_softmax_cross_entropy_with_logits 不需要。

https://www.jianshu.com/p/47172eb86b39

2.3 optimizer定义

    # define optimizer
    optimizer = torch.optim.SGD(model.get_config_optim(args.LR, args.LRP),
                                momentum=args.MOMENTUM,
                                weight_decay=args.WEIGHT_DECAY)

其中,需要更新的参数在get_config_optim之中

    def get_config_optim(self, lr, lrp):
        return [
            {'params': self.features.parameters(), 'lr': lrp},
            {'params': self.heads.parameters(), 'lr': lr},
        ]

2.4 定义顺序汇总

定义criterion = util.get_criterion(args),定义其loss类型,此变量无需任何变化。保留原结构。

定义optimizer,其中包含了优化器与需要优化的参数。

optimizer = torch.optim.SGD(model.get_config_optim(args.LR, args.LRP),
                            momentum=args.MOMENTUM,
                            weight_decay=args.WEIGHT_DECAY)

此变量涉及学习率以及loss需要更新的weight

三、更改相应代码

3.1 定义需要优化的参数

相应网络结构见上面 第一部分一、多输出网络

给相应的优化器设置对应的参数。

    def before_GAT_train_optim(self, lr):
        return [
            {'params': self.features.parameters(), 'lr': lr},
            {'params': self.heads.parameters(), 'lr': lr},
        ]

    def GAT_and_after_train_optim(self, lr):
        return [
            {'params': self.gat_and_final_fcs.parameters(), 'lr': lr},
        ]

    def final_total_model_train_optim(self,lr):
        return [
            {'params': self.features.parameters(), 'lr': lr},
            {'params': self.heads.parameters(), 'lr': lr},
            {'params': self.gat_and_final_fcs.parameters(), 'lr': lr},
        ]

读入optimizer并读入trainer

    before_GAT_optimizer = torch.optim.SGD(model.before_GAT_train_optim(lr=0.01),
                                           momentum=args.MOMENTUM,
                                           weight_decay=args.WEIGHT_DECAY)
    gat_and_after_optimizer = torch.optim.SGD(model.GAT_and_after_train_optim(lr=0.01),
                                              momentum=args.MOMENTUM,
                                              weight_decay=args.WEIGHT_DECAY)
    final_total_optimizer = torch.optim.SGD(model.final_total_model_train_optim(lr=0.001),
                                            momentum=args.MOMENTUM,
                                            weight_decay=args.WEIGHT_DECAY)

    trainer = T.Trainer(args=args, train_dataloader=train_dataloader, val_dataloader=val_dataloader,
                        before_GAT_optimizer=before_GAT_optimizer,gat_and_after_optimizer=gat_and_after_optimizer,
                        final_total_optimizer=final_total_optimizer,
                        model=model, criterion=criterion)

3.2 更改学习率

0-80优化器1,81-160优化器2, 161-240优化器3

其中,前30个epoch为基础学习率,往后每个退化0.9

run函数之中

                self.adjust_learning_rate(before_GAT_optimizer=self.before_GAT_optimizer,
                                          gat_and_after_optimizer=self.gat_and_after_optimizer,
                                          final_total_optimizer=self.final_total_optimizer, epoch=epoch)
    def adjust_learning_rate(self, before_GAT_optimizer,gat_and_after_optimizer,
                                    final_total_optimizer, epoch):
        """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
        if epoch<80:
            for i, param_group in enumerate(before_GAT_optimizer.param_groups):
                if epoch<30:
                    param_group['lr'] = 0.01
                    print('lr of weight group:',i, ':',param_group['lr'])
                else:
                    param_group['lr']=0.01*(0.9**(epoch-30))
                    print('lr of weight group:', i, ':', param_group['lr'])
        elif epoch<160:
            for i, param_group in enumerate(gat_and_after_optimizer.param_groups):
                if epoch<110:
                    param_group['lr'] = 0.01
                    print('lr of weight group:', i, ':', param_group['lr'])
                else:
                    param_group['lr'] = 0.01 * (0.9 ** (epoch - 110))
                    print('lr of weight group:', i, ':', param_group['lr'])
        else:
            for i, param_group in enumerate(final_total_optimizer.param_groups):
                if epoch<190:
                    param_group['lr'] = 0.0005
                    print('lr of weight group:', i, ':', param_group['lr'])
                else:
                    param_group['lr'] = 0.0005 * (0.9 ** (epoch - 190))
                    print('lr of weight group:', i, ':', param_group['lr'])

3.3 网络loss及结构

更改loss及结构。

            if (self.arch == 'group_clsgat_seq_train'):
                if self.loss_type == 'DeepMarLoss':
                    weights = self.deepmar_loss.weighted_label(target.detach())
                    if torch.cuda.is_available():
                        weights = weights.cuda()
                    output_loss = criterion(output, target, weight=weights)
                    supplement_loss=criterion(supplement_out, target, weight=weights)
                    # loss=output_loss+0.1*supplement_loss
                else:
                    output_loss=criterion(output, target)
                    supplement_loss=criterion(supplement_out, target)
                    # loss = output_loss+0.1*supplement_loss

            if epoch<80:
                before_GAT_optimizer.zero_grad()
                supplement_loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
                before_GAT_optimizer.step()
            elif epoch<160:
                gat_and_after_optimizer.zero_grad()
                output_loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
                gat_and_after_optimizer.step()
            else:
                final_total_optimizer.zero_grad()
                output_loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
                final_total_optimizer.step()
  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

祥瑞Coding

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值