PyTorch: Freezing Some Parameters and Training Only Specified Layers


Problem description

A transfer-learning-style setup: after loading pretrained weights into the model, freeze the parameters of the specified layers and train only the remaining layers, so that the frozen parameters are never updated while the rest are.

Implementation

1. Set a condition and mark every parameter that satisfies it as not requiring gradients:

# Layer1pre is the state dict loaded from the pretrained checkpoint (see the full code below)
for k, v in model2.named_parameters():
    if k in Layer1pre.keys():   # freeze every parameter that exists in the checkpoint
        v.requires_grad = False
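
If the frozen layers are better identified by name than by membership in the loaded checkpoint, a prefix match works the same way. A minimal sketch, assuming a module named layer1 (a hypothetical name; substitute your own):

# Alternative: freeze by parameter-name prefix ('layer1.' is a hypothetical module name)
for name, p in model2.named_parameters():
    if name.startswith('layer1.'):
        p.requires_grad = False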

2. Then be sure to filter the parameter list before passing it to the optimizer:

# keep only the parameters that still require gradients
params = filter(lambda p: p.requires_grad, model2.parameters())
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model2.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(params, lr=learning_rate)

If the filter step is skipped, the "frozen" parameters may not stay fixed: the training output still shows them being updated, as in the figure below.
[Figure: training output showing the supposedly frozen parameters still changing when the filter step is omitted]
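
Why can this happen? A parameter with requires_grad=False and a .grad of None is simply skipped by the optimizer, but a parameter that is frozen after it has already accumulated a gradient keeps its stale .grad, and an optimizer that still holds it will keep applying updates. A minimal sketch of that failure mode:

import torch

# A parameter frozen *after* it already has a gradient keeps being
# updated if it is still registered in the optimizer, because Adam
# reuses the stale .grad on every step.
p = torch.nn.Parameter(torch.ones(1))
opt = torch.optim.Adam([p], lr=0.1)

(p * 2).sum().backward()    # p.grad is now populated
p.requires_grad_(False)     # frozen too late: .grad is still set

before = p.detach().clone()
opt.step()                  # Adam still moves p using the stale gradient
print(torch.equal(before, p.detach()))  # False: the "frozen" p changed

Filtering the parameter list up front sidesteps the issue entirely, whatever state the gradients are in.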

3. Check that the flags took effect: frozen parameters print False, trainable ones print True:

for name, value in model2.named_parameters():
    print(name, value.requires_grad)  # prints True or False for every parameter

As shown below, after the filter step the first layer's parameters have requires_grad=False, so only the second layer's parameters will be updated during subsequent training.
[Figure: requires_grad printed per parameter; layer-1 entries show False, layer-2 entries show True]
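
A quicker sanity check than scanning the whole list is to count trainable versus total parameters after freezing:

# Quick check: how many parameters will actually be trained
trainable = sum(p.numel() for p in model2.parameters() if p.requires_grad)
total = sum(p.numel() for p in model2.parameters())
print(f'trainable parameters: {trainable} / {total}')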
Verification: the first-layer parameters of model2 are unchanged after training, as shown:
[Figure: first-layer weights of model2 compared before and after training; the difference is zero]
The complete code is as follows:

import time

import torch
import torch.nn as nn

# Assumed defined elsewhere: use_gpu, learning_rate, num_epochs, patience,
# train_loader, eval_loader, save_path, best_acc2, m; CNNLayer2 is the
# user-defined model and EarlyStopping is sketched after this listing.
Layer1pre = torch.load('./ResultData_earlystop/savemodel/checkpoint_model_layer1.pt')
model2 = CNNLayer2(num_classes=10, aux_logits=True)
if use_gpu:
    model2 = model2.cuda()

# Load the pretrained parameters (strict=False ignores non-matching keys)
model2.load_state_dict(Layer1pre, strict=False)

# Freeze every parameter that came from the layer-1 checkpoint
for k, v in model2.named_parameters():
    if k in Layer1pre.keys():
        v.requires_grad = False

# Hand only the trainable parameters to the optimizer
params = filter(lambda p: p.requires_grad, model2.parameters())
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model2.parameters(), lr=learning_rate)
optimizer = torch.optim.Adam(params, lr=learning_rate)

for name, value in model2.named_parameters():
    print(name, value.requires_grad)  # True for trainable, False for frozen

# Per-epoch records: train loss/acc and eval loss/acc
loss_train_epochs = []  # one training-loss value per epoch
acc_train_epochs = []

loss_eval_epoch_all = []
acc_eval_epoch_all = []

loss_train_avg = []
accuracy_train_avg = []
loss_eval_avg = []
accuracy_eval_avg = []


early_stopping = EarlyStopping(patience=patience, verbose=True, layer=2)

for epoch in range(num_epochs):
    epoch_start = time.time()
    print('*' * 15, f'Epoch: {epoch + 1}', '*' * 15)
    running_loss = 0.0
    running_acc = 0.0
    preds_all = []
    targets_all = []
    model2.train()
    for i, data in enumerate(train_loader):
        img, label = data
        if use_gpu:
            img = img.cuda()
            label = label.cuda()
        # forward pass
        logits = model2(img)

        loss = criterion(logits, label)
        running_loss += loss.item()

        _, pred = torch.max(logits, 1)
        running_acc += (pred == label).float().mean().item()

        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    TrainLoss = running_loss / len(train_loader)
    TrainAcc = running_acc / len(train_loader)

    loss_train_epochs.append(TrainLoss)
    acc_train_epochs.append(TrainAcc)

    print(f'TrainLoss: {TrainLoss:.6f}')
    print(f'TrainAcc: {TrainAcc:.6f}')

    model2.eval()
    eval_loss = 0.
    eval_acc = 0.
    for data in eval_loader:
        img, label = data
        if use_gpu:
            img = img.cuda()
            label = label.cuda()

        with torch.no_grad():
            logits = model2(img)
            loss = criterion(logits, label)
        eval_loss += loss.item()
        _, pred = torch.max(logits, 1)
        eval_acc += (pred == label).float().mean().item()


    EvalLoss = eval_loss / len(eval_loader)
    EvalAcc = eval_acc / len(eval_loader)

    if best_acc2 < EvalAcc:
        best_acc2 = EvalAcc
        best_epoch = epoch + 1
        best_experiment = m + 1
        torch.save(model2.state_dict(), save_path)

    loss_eval_epoch_all.append(EvalLoss)
    acc_eval_epoch_all.append(EvalAcc)

    # early_stopping needs the validation loss to check if it has decreased;
    # if it has, it makes a checkpoint of the current model
    early_stopping(EvalLoss, model2, 2)

    if early_stopping.early_stop:
        print("Early stopping")
        break

# load the last checkpoint with the best model
model2.load_state_dict(torch.load('./ResultData_earlystop/savemodel/checkpoint_model_layer2.pt'))
    
print("model.aux1.fc1.weight", model.aux1.fc1.weight)
print("model2.aux1.fc1.weight", model2.aux1.fc1.weight)
print("1-2", model2.aux1.fc1.weight-model2.aux1.fc1.weight)