def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train ``model`` for one epoch and return (mean loss, accuracy).

    Args:
        model: network to train; switched to train mode here.
        optimizer: optimizer stepped once per batch.
        data_loader: iterable yielding ``(images, labels)`` batches.
        device: device to move each batch to.
        epoch: epoch index, used only in the progress-bar label.

    Returns:
        Tuple ``(avg_loss, accuracy)`` — loss averaged over batches,
        accuracy over all samples seen this epoch.

    Exits the process via ``sys.exit(1)`` if the loss becomes non-finite
    (NaN or inf).
    """
    model.train()
    # Label smoothing of 0.1 softens the one-hot targets slightly.
    loss_function = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    accu_loss = torch.zeros(1).to(device)  # running sum of batch losses
    accu_num = torch.zeros(1).to(device)   # running count of correct predictions
    optimizer.zero_grad()

    sample_num = 0
    # Wrap the loader with tqdm so a progress bar is shown during training.
    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        # images.shape[0] is the batch size — assumes the usual
        # (batch, channels, height, width) tensor layout.
        sample_num += images.shape[0]

        # Transfer labels once per batch (the original called
        # labels.to(device) twice, doing the copy twice).
        labels = labels.to(device)
        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels).sum()

        loss = loss_function(pred, labels)
        # Abort *before* backprop if the loss is NaN/inf: gradients from a
        # non-finite loss are useless, and we exit either way.
        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        loss.backward()
        # detach() so the running sum does not keep the autograd graph alive.
        accu_loss += loss.detach()

        data_loader.desc = "[train epoch {}] loss: {:.3f}, acc: {:.3f}".format(
            epoch,
            accu_loss.item() / (step + 1),
            accu_num.item() / sample_num)

        optimizer.step()
        optimizer.zero_grad()

    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
02-24
11-06
1166
09-27
4994
07-11
3796
03-10
2554
05-01
2916
“相关推荐”对你有帮助么?
-
非常没帮助
-
没帮助
-
一般
-
有帮助
-
非常有帮助
提交