pytorch保存模型

1.训练集的划分

    full_dataset = Dataset(root=args.root_dir, dtype=dtype)  
    train_size = int(0.8 * len(full_dataset))  # 训练集验证集比例=4:1
    val_size = len(full_dataset) - train_size

(1) torch.utils.data.random_split()

    # 按照给定的长度将数据集划分成没有重叠的新数据集组合。
    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])

(2) torch.utils.data.Subset()

    # 获取指定一个索引序列对应的子数据集
    train_dataset = torch.utils.data.Subset(full_dataset, [:train_size])
    val_dataset = torch.utils.data.Subset(full_dataset, [train_size:])

2.数据集的加载

    dataloader = data.DataLoader(train_dataset, batch_size=1, shuffle=is_shuffle, num_workers=1)
    dataloader_val = data.DataLoader(val_dataset, batch_size=1, shuffle=is_shuffle, num_workers=1)

3.创建模型,模型训练

(1)每固定epoch后,保存模型
 if epoch % 50 == 0:
          torch.save(net.state_dict(),'%d.pth' % (epoch))
(2)在模型的验证集性能从上升到下降时,保存模型。
         min_loss_val = 10  # 任取一个大数
         best_model = None
         min_epoch = 100  # 训练至少需要的轮数
         for epoch in range(args.epochs):
             loss_val, loss_acc = train(epoch)
             if epoch > min_epoch and loss_val <= min_loss_val:
             min_loss_val = loss_val
             best_model = copy.deepcopy(model)
         model = best_model

(3)训练时,每个epoch保存在验证集的损失,当损失小于上个epoch时,保存模型,替换上一个epoch的模型

         min_loss = 100000 # 随便设置一个比较大的数
         for epoch in range(epochs):
             train()
             val_loss = val()
             if val_loss < min_loss:
                 min_loss = val_loss
                 print("save model")
                 torch.save(net.state_dict(),'model.pth')

4.加载模型参数:

model = my_CNN()
# 若加载使用多GPU训练的模型参数,必须有,否则会出错:RuntimeError: Error(s) in loading state_dict for my_CNN:Missing key(s) in state_dict: xxxxxxxx,Unexpected key(s) in state_dict: xxxxxxxxxx
model = nn.DataParallel(model,device_ids=[1, 2, 3])
model.to(device)
model.load_state_dict(torch.load('pre_trained_cnn.pth'))
  • 0
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值