PyTorch 训练模型的通用流程。
模型的训练基本上就是确定两件事:一个是优化器(torch.optim),一个是损失函数(nn.MSELoss、nn.CrossEntropyLoss、nn.NLLLoss)。优化器:需要模型参数。损失函数:需要 datasets 提供的数据。这也是损失函数以类接口形式提供的原因。
from tqdm import tqdm
import torch
import time
import pickle
def train(epoches, model, dataloader, optim, criterion, model_save_path):
    """Train ``model`` and checkpoint it whenever the epoch loss improves.

    For every epoch the function iterates over ``dataloader``, runs a
    forward pass, computes the loss, back-propagates and steps the
    optimizer. The per-batch losses are summed into one epoch loss.

    Args:
        epoches: Number of passes over ``dataloader``.
        model: Callable applied to each input batch; it must also provide
            ``state_dict()`` so that it can be checkpointed.
        dataloader: Iterable yielding ``(inputs, targets)`` pairs.
        optim: Optimizer exposing ``zero_grad()`` and ``step()``.
        criterion: Callable ``criterion(prediction, target)`` returning a
            loss object that supports ``backward()`` and ``item()``.
        model_save_path: Directory that receives the checkpoints. It is
            not created here, so it has to exist already.

    Returns:
        A list with the summed loss of every epoch, in epoch order.

    Side effects:
        Each time an epoch's summed loss is strictly lower than the best
        one seen so far, ``model.state_dict()`` is written to
        ``{model_save_path}/{n}.pth`` where ``n`` counts the checkpoints
        written so far (0, 1, 2, ...).
    """
    loss_mark = float('inf')  # lowest epoch loss seen so far
    loss_store = []  # history of epoch losses, returned to the caller
    model_num = 0  # index of the next checkpoint file
    for _ in tqdm(range(epoches)):
        loss_sum = 0.0
        for batch_line, batch_category in dataloader:
            batch_pred = model(batch_line)
            loss = criterion(batch_pred, batch_category)
            # Clear gradients left over from the previous step before
            # accumulating the new ones.
            optim.zero_grad()
            loss.backward()
            optim.step()
            loss_sum += loss.item()
        loss_store.append(loss_sum)
        # Compare against the running best (not the history list) and keep
        # the history intact; only the best marker is updated.
        if loss_sum < loss_mark:
            loss_mark = loss_sum
            torch.save(
                model.state_dict(),
                f'{model_save_path}/{model_num}.pth',
                pickle_module=pickle,
                pickle_protocol=2,
            )
            model_num += 1
    return loss_store
模型评测
def valid(model, dataloader, criterion):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The evaluation runs under ``torch.no_grad()``. When ``model`` is a
    ``torch.nn.Module`` it is switched to evaluation mode for the duration
    of the call, so that mode-dependent layers such as dropout or batch
    normalization behave deterministically. The previous train/eval mode
    of every submodule is restored afterwards, even if an error occurs.
    Any other callable is invoked as-is.

    Args:
        model: Callable applied to each input batch.
        dataloader: Sized iterable yielding ``(inputs, targets)`` pairs.
        criterion: Callable ``criterion(prediction, target)`` returning a
            loss object that supports ``item()``.

    Returns:
        The sum of the batch losses divided by ``len(dataloader)``.

    Raises:
        ZeroDivisionError: If ``dataloader`` is empty.
    """
    # Remember each submodule's mode individually so that a model with
    # deliberately frozen (eval-mode) parts is restored exactly.
    if isinstance(model, torch.nn.Module):
        saved_modes = [(module, module.training) for module in model.modules()]
        model.eval()
    else:
        saved_modes = []

    loss_sum = 0.0
    try:
        with torch.no_grad():
            for batch_line, batch_category in dataloader:
                batch_pred = model(batch_line)
                loss = criterion(batch_pred, batch_category)
                loss_sum += loss.item()
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
    return loss_sum / len(dataloader)