train.py 的代码基本结构。
- 读取与加载数据
从硬盘读取数据
# Build the training and validation datasets from their directories, each
# with its own transform pipeline.
# NOTE(review): XXXDataset looks like a placeholder for the project's own
# Dataset class -- confirm the real class name and constructor signature.
train_data = XXXDataset(data_dir=train_dir, transform=train_transform)
valid_data = XXXDataset(data_dir=valid_dir, transform=valid_transform)
用 DataLoader 封装数据集,按批次(batch)为网络提供数据
# Wrap both datasets in DataLoaders that yield batches of BATCH_SIZE samples.
# Only the training loader requests shuffling; the validation loader leaves
# `shuffle` at its default.
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)
- 加载模型、初始化权重
# Instantiate the model with two output classes, then run the model's own
# weight-initialization routine before training starts.
net = LeNet(classes=2)
net.initialize_weights()
- 选择损失函数
# Cross-entropy loss is used as the training criterion.
criterion = nn.CrossEntropyLoss()
- 优化器
# Choose the optimizer: SGD over all model parameters, with initial learning
# rate LR and momentum 0.9.
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# Set the learning-rate schedule: StepLR multiplies the learning rate by
# gamma (0.1) once every step_size (100) calls to scheduler.step().
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)
- 迭代训练
# Main training loop: one pass over train_loader per epoch.
for epoch in range(MAX_EPOCH):
    # Running statistics, reset at the start of every epoch.
    loss_mean = 0.
    correct = 0.
    total = 0.

    # Enable training mode (net is presumably an nn.Module -- confirm).
    net.train()
    for i, data in enumerate(train_loader):
        # forward
        ...
        # backward
        ...
        # update weights
        ...
        # collect classification statistics, accuracy, IoU, etc.
        ...
        # print training info
        ...

    # NOTE(review): the original snippet had lost its indentation, so the
    # nesting level of this call is reconstructed. It is placed at epoch
    # level, i.e. StepLR's step_size counts epochs -- confirm against the
    # real train.py.
    scheduler.step()  # update the learning rate