这是一个包含训练集、测试集和验证集的 PyTorch 模型的代码实现。这里使用 CIFAR-10 数据集作为示例数据集。
import torch
import torchvision
import torchvision.transforms as transforms
# 定义图像变换
transform = transforms.Compose(
[transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 加载训练集数据
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
# 定义训练集数据加载器
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
shuffle=True, num_workers=2)
# 加载测试集数据
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
# 定义测试集数据加载器
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
shuffle=False, num_workers=2)
# 划分验证集数据
train_size = int(0.8 * len(trainset))
val_size = len(trainset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(trainset, [train_size, val_size])
# 定义验证集数据加载器
valloader = torch.utils.data.DataLoader(val_dataset, batch_size=128,
shuffle=False, num_workers=2)
在上面的代码中,torchvision 库中的 transforms 模块用于定义用于数据预处理的转换,例如将数据缩放到 [0, 1] 或者标准化数据。使用 Compose 函数将多个转换串联起来。
在这段代码中,我们加载 CIFAR-10 数据集并定义了图像变换,然后将数据加载到训练集、测试集和验证集。为了划分出验证集,在加载训练集时,我们使用 PyTorch 的 DataLoader 类和 random_split 函数将数据划分成训练集和验证集。
加载数据后,我们可以像这样访问这些数据集:
# 访问训练集
for i, data in enumerate(trainloader, 0):
inputs, labels = data
# 执行训练步骤
# 访问测试集
for i, data in enumerate(testloader, 0):
inputs, labels = data
# 执行测试步骤
# 访问验证集
for i, data in enumerate(valloader, 0):
inputs, labels = data
# 执行验证步骤