Building a Neural Network (PyTorch): train()

The main steps of training a neural network are as follows (a minimal loop combining them is sketched right after the list):

  1. Zero the gradients: optimizer.zero_grad()
  2. Move the data to the device: inputs, labels = inputs.to(device), labels.to(device)
  3. Forward pass: outputs = model(inputs)
  4. Compute the loss: loss = criterion(outputs, labels)
  5. Backpropagate to compute gradients: loss.backward()
  6. Update the trainable weights: optimizer.step()
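
Put together, one training step looks like the loop below. This is only a minimal sketch, assuming that model, criterion, optimizer, tr_loader, and device have already been created; the train() function defined next wraps the same six steps with logging, validation, and early stopping.

for inputs, labels in tr_loader:
	optimizer.zero_grad()                                    # 1. Zero the gradients
	inputs, labels = inputs.to(device), labels.to(device)    # 2. Move data to the device
	outputs = model(inputs)                                  # 3. Forward pass
	loss = criterion(outputs, labels)                        # 4. Compute the loss
	loss.backward()                                          # 5. Backpropagate to get gradients
	optimizer.step()                                         # 6. Update the trainable weights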

Define the train() function:

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

def train(tr_loader, val_loader, model, config, device):
	''' Model training '''
	# Define a loss function
	criterion = nn.CrossEntropyLoss()
	# Define an optimizer
	optimizer = getattr(torch.optim, config['optimizer'])(model.parameters(), **config['optim_hparas'])
	
	n_epochs = config['n_epochs']    # Maximum number of epochs
	loss_record = []
	val_acc_record = []
	val_acc = 0.0
	max_val_acc = 0.0
	early_stop_cnt = 0
	
	for epoch in range(n_epochs):
		model.train()
		tr_bar = tqdm(tr_loader)
		for step, data in enumerate(tr_bar):
			inputs, labels = data
			optimizer.zero_grad()    # Zero the gradients accumulated in the previous step
			inputs, labels = inputs.to(device), labels.to(device)
			outputs = model(inputs)
			loss = criterion(outputs, labels)
			loss.backward()
			optimizer.step()
			
			loss_record.append(loss.item())    # Record train_loss
			
			# Update train_bar
			tr_bar.desc = f'[Train | {epoch+1:03d}/{n_epochs:03d}] loss={loss.item():.3f}'

		val_acc = valid(val_loader, model, device)
		if val_acc > max_val_acc:
			max_val_acc = val_acc
			# save model if model improved
			torch.save(model.state_dict(), config['save_path'])
			early_stop_cnt = 0
		else:
			early_stop_cnt += 1    # Count epochs in which val accuracy did not improve

		val_acc_record.append(val_acc)
		if early_stop_cnt > config['early_stop']:
			# Stop training early if accuracy has stopped improving
			break

	print('Finished training.')
	return loss_record, val_acc_record


def valid(val_loader, model, device):
	model.eval()    # Set model to evaluation mode
	acc = 0.0
	with torch.no_grad():
		val_bar = tqdm(val_loader)
		for data in val_bar:
			inputs, labels = data
			inputs, labels = inputs.to(device), labels.to(device)
			outputs = model(inputs)
			pred = torch.max(outputs, dim=1)[1]     # Predicted class indices
			acc += (pred == labels).sum().item()    # Count correct predictions in this batch
	return acc / len(val_loader.dataset)    # Overall validation accuracy
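
For example, the two functions could be driven as sketched below. The config keys are the ones read inside train() (optimizer, optim_hparas, n_epochs, early_stop, save_path), while MyModel, tr_loader, and val_loader are hypothetical placeholders for your own network and DataLoaders.

device = 'cuda' if torch.cuda.is_available() else 'cpu'

config = {
	'optimizer': 'Adam',                    # Any optimizer class name in torch.optim
	'optim_hparas': {'lr': 1e-3},           # Keyword arguments passed to that optimizer
	'n_epochs': 100,                        # Maximum number of epochs
	'early_stop': 10,                       # Stop after this many epochs without improvement
	'save_path': 'models/best_model.pth',   # Where the best checkpoint is saved
}

model = MyModel().to(device)    # Placeholder model; replace with your own network
loss_record, val_acc_record = train(tr_loader, val_loader, model, config, device)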