PyTorch: a general workflow for training models.

Training a model essentially comes down to fixing two things: an optimizer (from torch.optim) and a loss function (e.g. nn.MSELoss, nn.CrossEntropyLoss, nn.NLLLoss). The optimizer is constructed from the model's parameters; the loss function consumes predictions and targets coming from the dataset. This is also why loss functions are provided as class interfaces: you instantiate the loss module once, then call it on each batch.
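As a minimal sketch of that pairing (the tiny model, learning rate, and fake batch below are invented for illustration, not from the original post): the optimizer is bound to model.parameters(), while the loss is a module instantiated once and then called with (predictions, targets).

```python
import torch
import torch.nn as nn

# hypothetical tiny classifier, just to show the wiring
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

# optimizer: constructed from the model's parameters
optim = torch.optim.SGD(model.parameters(), lr=0.01)

# loss function: instantiated once as a module
criterion = nn.CrossEntropyLoss()

x = torch.randn(2, 4)            # fake batch of 2 samples
target = torch.tensor([0, 2])    # fake class labels
loss = criterion(model(x), target)
print(loss.item())
```

Because the loss is an nn.Module, the same criterion object is reused across every batch and every epoch.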

from tqdm import tqdm
import torch
import time
import pickle

def train(epochs, model, dataloader, optim, criterion, model_save_path):
    loss_mark = float('inf')   # best (lowest) epoch loss seen so far
    loss_store = []            # per-epoch loss history
    model_num = 0
    for _ in tqdm(range(epochs)):
        loss_sum = 0
        start = time.time()    # epoch start time, for optional timing
        for batch_line, batch_category in dataloader:
            batch_pred = model(batch_line)
            loss = criterion(batch_pred, batch_category)
            optim.zero_grad()  # clear gradients from the previous step
            loss.backward()    # backpropagate
            optim.step()       # update parameters
            loss_sum += loss.item()
        loss_store.append(loss_sum)
        # save a checkpoint whenever the epoch loss improves
        if loss_sum < loss_mark:
            loss_mark = loss_sum
            torch.save(model.state_dict(),
                       f'{model_save_path}/{model_num}.pth',
                       pickle_module=pickle, pickle_protocol=2)
            model_num += 1
    return loss_store

Model evaluation

def valid(model, dataloader, criterion):
    loss_sum = 0
    model.eval()               # disable dropout / batch-norm updates
    with torch.no_grad():      # no gradients needed during evaluation
        for batch_line, batch_category in dataloader:
            batch_pred = model(batch_line)
            loss = criterion(batch_pred, batch_category)
            loss_sum += loss.item()
    return loss_sum / len(dataloader)   # mean loss per batch
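One detail worth checking: inside torch.no_grad() no autograd graph is built, so evaluation costs no gradient memory. A quick sketch (the linear model here is a throwaway stand-in):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)          # throwaway model for the demo
x = torch.randn(8, 4)

train_out = model(x)             # normal forward pass: tracks gradients
model.eval()                     # switch to evaluation mode
with torch.no_grad():            # disable graph construction
    eval_out = model(x)

print(train_out.requires_grad)   # True
print(eval_out.requires_grad)    # False
```

model.eval() and torch.no_grad() are independent: the first changes layer behavior (dropout, batch norm), the second disables gradient tracking, and a proper evaluation loop uses both.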