【chatgpt】使用train_test_split来分割数据集

可以使用 train_test_split 来分割数据集,方便快捷地将数据集分成训练集、验证集和测试集。下面是如何在前面的代码基础上修改,使用 train_test_split 来分割数据集的示例:

代码实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split, Subset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# 定义一个简单的神经网络
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# 生成数据集
def create_dataset():
    data = torch.randn(10000, 784)
    labels = torch.randint(0, 10, (10000,))
    return data, labels

# 将数据集分为训练集、验证集和测试集
def split_dataset(data, labels):
    train_data, temp_data, train_labels, temp_labels = train_test_split(data, labels, test_size=0.3, random_state=42)
    val_data, test_data, val_labels, test_labels = train_test_split(temp_data, temp_labels, test_size=0.5, random_state=42)
    return train_data, val_data, test_data, train_labels, val_labels, test_labels

# 训练函数
def train(model, train_loader, criterion, optimizer):
    model.train()
    running_train_loss = 0.0
    for inputs, targets in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_train_loss += loss.item()
    
    avg_train_loss = running_train_loss / len(train_loader)
    return avg_train_loss

# 测试函数,计算相对误差
def test(model, data_loader):
    model.eval()
    total_relative_error = 0.0
    with torch.no_grad():
        for inputs, targets in data_loader:
            outputs = model(inputs)
            predicted = torch.argmax(outputs, dim=1)
            relative_error = torch.abs(predicted.float() - targets.float()) / torch.abs(targets.float())
            total_relative_error += relative_error.sum().item()
    
    avg_relative_error = total_relative_error / len(data_loader.dataset)
    return avg_relative_error

# 主函数
def main():
    model = SimpleNet()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    data, labels = create_dataset()
    train_data, val_data, test_data, train_labels, val_labels, test_labels = split_dataset(data, labels)

    train_dataset = TensorDataset(train_data, train_labels)
    val_dataset = TensorDataset(val_data, val_labels)
    test_dataset = TensorDataset(test_data, test_labels)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    num_epochs = 20
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        train_loss = train(model, train_loader, criterion, optimizer)
        val_loss = test(model, val_loader)  # 计算验证集上的相对误差

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Relative Error: {val_loss:.4f}')

    plt.plot(range(1, num_epochs + 1), train_losses, label='Train Loss')
    plt.plot(range(1, num_epochs + 1), val_losses, label='Validation Relative Error')
    plt.xlabel('Epochs')
    plt.ylabel('Loss / Relative Error')
    plt.legend()
    plt.title('Training Loss and Validation Relative Error Curve')
    plt.show()

    # 最终测试集上的评估
    test_relative_error = test(model, test_loader)  # 计算测试集上的相对误差
    print(f'Test Relative Error: {test_relative_error:.4f}')

if __name__ == "__main__":
    main()

代码解释

  1. 生成数据集

    • create_dataset 函数生成一个随机数据集,包括数据和标签。
  2. 分割数据集

    • split_dataset 函数使用 train_test_split 将数据集分割为训练集、验证集和测试集。首先将数据集分割为训练集和临时集,然后再将临时集分割为验证集和测试集。
  3. 训练函数

    • train 函数用于模型训练,返回训练集的平均损失。
  4. 测试函数

    • test 函数计算并返回相对误差,使用测试集和验证集的输入和目标标签计算相对误差。
  5. 主函数

    • main 函数进行模型的训练和验证,并在每个epoch结束时输出训练损失和验证集上的相对误差。最终在测试集上计算相对误差并输出。

通过使用 train_test_split,可以更方便地分割数据集,并确保训练集、验证集和测试集之间的比例符合要求。这样做可以使得数据集分割过程更加简洁明了。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值