深度学习在训练时更新和保存最佳训练结果的方法(字典方法,本地保存方法,模型深拷贝方法)

1.用参数字典 model.state_dict()更新最优参数

best_state_dict = model.state_dict()  # 训练前
best_state_dict = model.state_dict()  # 训练时更新最优state_dict

完整代码:

 # 初始化一个变量来保存最优的state_dict
  best_state_dict = model.state_dict()
  for epoch in range(epochs):
      model.train()
      # 训练集上训练模型权重
      for data, targets in tqdm.tqdm(train_dataloader):
          # 把数据加载到GPU上
          data = data.to(devices[0])
          targets = targets.to(devices[0])

          # 前向传播
          preds = model(data)
          loss = criterion(preds, targets)

          # 反向传播
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()

      # 测试集上评估模型性能
      model.eval()
      num_correct = 0
      num_samples = 0
      with torch.no_grad():
          for x, y in tqdm.tqdm(test_dataloader):
              x = x.to(devices[0])
              y = y.to(devices[0])
              preds = model(x)
              predictions = preds.max(1).indices  # 返回每一行的最大值和该最大值在该行的列索引
              num_correct += (predictions == y).sum()
              num_samples += predictions.size(0)
          acc = (num_correct / num_samples).item()
          if acc > best_acc:
              best_acc = acc
              best_epoch = epoch+1
              # 保存模型最优准确率的参数
              best_state_dict = model.state_dict()  # 更新最优state_dict
      model.train()
  # 训练结束保存
  torch.save(best_state_dict, f"weights/{model_name}_{epochs}_{best_acc}.pth")

2.训练过程中保存最优参数

if acc > best_acc:
    best_acc = acc
    best_epoch = epoch+1
    torch.save(best_state_dict, f"weights/{model_name}_{epochs}_{best_acc}.pth")

3.对模型深拷贝方法保存最优模型

深拷贝方法介绍

copy模块可以用来创建一个对象的深拷贝。这意味着复制后的模型和原始模型是完全独立的,包括它们的参数。

import torch  
import copy  
import torch.nn as nn  
  
# 假设我们有一个模型实例  
original_model = nn.Sequential(  
    nn.Linear(10, 5),  
    nn.ReLU(),  
    nn.Linear(5, 2)  
)  
  
# 复制模型  
model_copy = copy.deepcopy(original_model)

深拷贝方法保存最优模型

best_model = copy.deepcopy(model.state_dict())  # 训练前
best_model = copy.deepcopy(model.state_dict())  # 训练时更新最优state_dict

代码案例:

   def fit_zsl(self):
        best_acc = 0
        mean_loss = 0
        last_loss_epoch = 1e8
        # 定义best_model
        best_model = copy.deepcopy(self.model.state_dict())
        for epoch in range(self.nepoch):
            for i in range(0, self.ntrain, self.batch_size):
                self.model.zero_grad()
                batch_input, batch_label = self.next_batch(self.batch_size)
                self.input.copy_(batch_input)
                self.label.copy_(batch_label)

                inputv = Variable(self.input)
                labelv = Variable(self.label)
                output = self.model(inputv)
                loss = self.criterion(output, labelv)
                mean_loss += loss.item()
                loss.backward()
                self.optimizer.step()
            acc = self.val(
                self.test_unseen_feature,
                self.test_unseen_label,
                self.unseenclasses,
            )
            if acc > best_acc:
                best_acc = acc
                # 更新best_model
                best_model = copy.deepcopy(self.model.state_dict())
        #训练完毕本地保存
		torch.save(best_model.state_dict(), f"weights/{self.nepoch}_{best_acc}.pth")
        return best_acc, best_model
  • 17
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

computer_vision_chen

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值