神经网络搭建(Pytorch)——train()

神经网络训练的主要步骤如下:

  1. 梯度清零:optimizer.zero_grad()
  2. 将数据喂入设备:inputs, labels = inputs.to(device), labels.to(device)
  3. 前向传播:outputs = model(inputs)
  4. 计算损失函数:loss = criterion(outputs, labels)
  5. 计算梯度:loss.backward()
  6. 更新可训练权重:optimizer.step()

定义 train() 函数:

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

def train(tr_loader, val_loader, model, config, device):
	''' Model training '''
	# Define a loss function
	criterion = nn.CrossEntropyLoss()
	# Define an optimizer
	optimizer = getattr(torch.optim, config['optimizer'])(model.parameters(), **config['optim_hparas'])
	
	n_epochs = config['n_epochs']    # Mininum number of epochs
	loss_record = []
	val_acc_record = []
	val_acc = 0.0
	max_val_acc = 0.0
	early_stop_cnt = 0
	
	for epoch in range(n_epochs):
		model.train()
		tr_bar = tpdm(tr_loader)
		for step, data in enumerate(tr_bar):
			inputs, labels = data
			inputs, labels = inputs.to(device), labels.to(device)
			outputs = model(inputs)
			loss = criterion(outputs, labels)
			loss.backward()
			optimier.step()
			
			loss_record.append(loss.item())    # Record train_loss
			
			# Update train_bar
			tr_bar.desc = '[Train|{epoch+1:03d}/{n_epochs:03d} loss={loss:.3f}'

		val_acc = valid(val_loader, model, device)
		if val_acc > max_val_acc:
			max_val_acc = val_acc
			# save model if model improved
			torch.save(model.state_dict(), config['save_path'])
			early_stop_cnt = 0
		else:
			early_stop_cnt += 1    # 记录 acc 没有升高的 epoch 数

		val_acc_record.append(val_acc)
		if early_stop_cnt > config['early_stop']:
			# Stop train if acc difficultly improve
			break

	print('Finish training.')
	return loss_record, val_acc_record


def valid(val_loader, model, device):
	model.eval()    # Set model to evaluation mode
	acc = 0.0
	with torch.no_grad():
		val_bar = tqdm(val_loader)
		for data in val_bar:
			inputs, labels = data
			inputs, labels = inputs.to(device), labels.to(device)
			outputs = model(inputs)
			pred = torch.max(outputs, dim=1)[1]
			acc += (pred.argmax(dim=-1) == label).float().mean()
	return acc
  • 6
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch 是一个基于 Python 的科学计算库,它支持动态计算图,使得神经网络搭建和训练变得非常方便。以下是 PyTorch 搭建神经网络的基本步骤: 1. 导入 PyTorch 库和其他必要的库,如 numpy。 ```python import torch import torch.nn as nn import numpy as np ``` 2. 定义神经网络的结构,可以使用 PyTorch 提供的各种层(如全连接层、卷积层等)来搭建网络。 ```python class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(784, 512) self.fc2 = nn.Linear(512, 256) self.fc3 = nn.Linear(256, 10) def forward(self, x): x = x.view(-1, 784) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x ``` 上面的代码定义了一个三层的全连接神经网络,输入为 784 维,输出为 10 维。 3. 定义损失函数和优化器。 ```python criterion = nn.CrossEntropyLoss() optimizer = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9) ``` 上面的代码使用交叉熵作为损失函数,使用随机梯度下降(SGD)作为优化器。 4. 训练神经网络。 ```python for epoch in range(num_epochs): for i, (images, labels) in enumerate(train_loader): images = images.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = net(images) loss = criterion(outputs, labels) loss.backward() optimizer.step() if (i+1) % 100 == 0: print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, total_step, loss.item())) ``` 上面的代码使用 PyTorch 提供的 DataLoader 加载数据,然后进行迭代训练。在每次迭代中,需要将数据和标签转换为 Tensor,并将其发送到 GPU 上(如果有的话),然后进行正向传播、计算损失、反向传播和更新模型参数。在训练过程中,可以打印出损失值和当前的训练进度。 5. 测试神经网络。 ```python with torch.no_grad(): correct = 0 total = 0 for images, labels in test_loader: images = images.to(device) labels = labels.to(device) outputs = net(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total)) ``` 上面的代码使用测试集来评估神经网络的准确率。在测试时,需要关闭梯度计算,以节省内存和时间。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值