首先,我们需要导入必要的库和模块:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
接下来,我们定义一些超参数:
# 超参数
batch_size = 32
learning_rate = 0.001
num_epochs = 10
然后,我们定义数据增强和数据加载器:
# 数据增强
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 数据加载器
train_set = ImageFolder('train', transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_set = Ima