Datawhale 零基础入门CV赛事-Task4 模型训练与验证

一、构造验证集

1、过拟合:在模型的训练过程中,模型只能利用训练数据来进行训练,模型并不能接触到测试集上的样本。因此模型如果将训练集学的过好,模型就会记住训练样本的细节,导致模型在测试集的泛化效果较差,与之相反的是欠拟合。
解决过拟合的办法是:构建一个与测试集尽可能分布一致的样本集(可称为验证集),在训练过程中不断验证模型在验证集上的精度,并以此控制模型的训练。在一般情况下,参赛选手也可以自己在本地划分出一个验证集出来,进行本地验证。

训练集(Train Set):模型用于训练和调整模型参数;

验证集(Validation Set):用来验证模型精度和调整模型超参数;

测试集(Test Set):验证模型的泛化能力

本地划分验证集的方法:留出法、交叉验证法、自主采样法。

二、模型训练与验证

构造训练集和验证集;
每轮进行训练和验证,并根据最优验证集精度保存模型。

import torch.nn as nn
from torch.utils.data.dataset import Dataset

train_loader = torch.utils.data.DataLoader(
#     train_dataset,
    batch_size=10, 
    shuffle=True, 
    num_workers=10, 
)

val_loader = torch.utils.data.DataLoader(
#     val_dataset,
    batch_size=10, 
    shuffle=False, 
    num_workers=10, 
)

model = SVHN_Model1()
criterion = nn.CrossEntropyLoss (size_average=False)
optimizer = torch.optim.Adam(model.parameters(), 0.001)
best_loss = 1000.0
for epoch in range(20):
    print('Epoch: ', epoch)

train(train_loader, model, criterion, optimizer, epoch)
    val_loss = validate(val_loader, model, criterion)
    
    # 记录下验证集精度
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save(model.state_dict(), './model.pt')

其中每一个Epoch的训练代码:

def train(train_loader, model, criterion, optimizer, epoch):
    # 切换模型为训练模式
    model.train()

for i, (input, target) in enumerate(train_loader):
        c0, c1, c2, c3, c4, c5 = model(data[0])
        loss = criterion(c0, data[1][:, 0]) + \
                criterion(c1, data[1][:, 1]) + \
                criterion(c2, data[1][:, 2]) + \
                criterion(c3, data[1][:, 3]) + \
                criterion(c4, data[1][:, 4]) + \
                criterion(c5, data[1][:, 5])
        loss /= 6
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

其中每个Epoch的验证代码如下:

def validate(val_loader, model, criterion):
    # 切换模型为预测模型
    model.eval()
    val_loss = []
    
 # 不记录模型梯度信息
    with torch.no_grad():
        for i, (input, target) in enumerate(val_loader):
            c0, c1, c2, c3, c4, c5 = model(data[0])
            loss = criterion(c0, data[1][:, 0]) + \
                    criterion(c1, data[1][:, 1]) + \
                    criterion(c2, data[1][:, 2]) + \
                    criterion(c3, data[1][:, 3]) + \
                    criterion(c4, data[1][:, 4]) + \
                    criterion(c5, data[1][:, 5])
            loss /= 6
            val_loss.append(loss.item())
    return np.mean(val_loss)

三、模型保存与加载

在Pytorch中模型的保存和加载非常简单,比较常见的做法是保存和加载模型参数:

torch.save(model_object.state_dict(), 'model.pt')
model.load_state_dict(torch.load(' model.pt')) 
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值