pytorch神经网络学习笔记04----包含训练集、验证集、测试集的过程代码

该代码示例展示了一个使用PyTorch实现的深度学习模型,包括两个全连接层。模型训练通过遍历训练集、计算损失和更新参数进行。验证和测试函数计算验证集和测试集的损失及准确率。在main函数中,定义超参数、数据加载器,并进行模型训练、验证和测试。
摘要由CSDN通过智能技术生成

该代码示例包含了训练、测试和验证过程的实现。首先,定义一个叫做Model的模型,它包含两个全连接层。然后定义了train函数,它遍历训练集,计算损失并更新模型参数。在同样的方式下,实现validate和test函数,它们分别计算验证集和测试集的损失和准确率。

在main函数中,定义了一系列超参数和设备,然后将训练集、验证集和测试集加载进来。接着,创建一个新的Model对象,定义损失函数和优化器,并开始训练和验证过程。最后,调用test函数进行测试。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# 定义模型
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(100, 50)
        self.fc2 = nn.Linear(50, 10)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 训练
def train(model, optimizer, criterion, train_loader, device):
    total_loss = 0
    
    # 将模型置为训练模式
    model.train()
    
    # 遍历训练集
    for data, target in train_loader:
        # 将数据移到设备上
        data, target = data.to(device), target.to(device)
        # 清零梯度
        optimizer.zero_grad()
        # 前向传播
        output = model(data)
        # 计算损失
        loss = criterion(output, target)
        # 反向传播
        loss.backward()
        # 更新参数
        optimizer.step()
        # 统计损失
        total_loss += loss.item()
    
    # 计算平均损失
    avg_loss = total_loss / len(train_loader)
    
    return avg_loss

# 验证
def validate(model, criterion, val_loader, device):
    total_loss = 0
    
    # 将模型置为验证模式
    model.eval()
    
    with torch.no_grad():
        # 遍历验证集
        for data, target in val_loader:
            # 将数据移到设备上
            data, target = data.to(device), target.to(device)
            # 前向传播
            output = model(data)
            # 计算损失
            loss = criterion(output, target)
            # 统计损失
            total_loss += loss.item()

    # 计算平均损失
    avg_loss = total_loss / len(val_loader)
    
    return avg_loss

# 测试
def test(model, criterion, test_loader, device):
    total_loss = 0
    total_correct = 0
    
    # 将模型置为测试模式
    model.eval()
    
    with torch.no_grad():
        # 遍历测试集
        for data, target in test_loader:
            # 将数据移到设备上
            data, target = data.to(device), target.to(device)
            # 前向传播
            output = model(data)
            # 计算损失
            loss = criterion(output, target)
            # 统计损失
            total_loss += loss.item()
            # 计算准确率
            _, predicted = torch.max(output.data, 1)
            total_correct += (predicted == target).sum().item()

    # 计算平均损失
    avg_loss = total_loss / len(test_loader)
    # 计算准确率
    accuracy = total_correct / len(test_loader.dataset)
    
    return avg_loss, accuracy

# 主函数
def main():
    # 定义超参数
    epochs = 10
    lr = 0.01
    batch_size = 32
    
    # 定义设备
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 定义训练集、验证集、测试集
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    
    # 定义模型
    model = Model().to(device)
    
    # 定义损失函数和优化器
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)
    
    # 训练和验证
    for epoch in range(1, epochs + 1):
        train_loss = train(model, optimizer, criterion, train_loader, device)
        val_loss = validate(model, criterion, val_loader, device)
        
        print(f"Epoch: {epoch}, Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}")
    
    # 测试
    test_loss, test_accuracy = test(model, criterion, test_loader, device)
    print(f"Test loss: {test_loss:.4f}, Test accuracy: {test_accuracy:.4f}")

if __name__ == '__main__':
    main()

这是一个基本框架,使用这个框架,就可以完成很多使用pytorch的项目。

  • 11
    点赞
  • 89
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值