2021/8/17
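Use tqdm to show training progress in real time, together with the running training-set accuracy and loss.
The result looks like this:
Implementation: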
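```python
import torch
from tqdm import tqdm


def train(model, criterion, optimizer, trainloader, Epoch, EPOCHS, BATCH_SIZE):
    # BATCH_SIZE is no longer needed for the accuracy (samples are counted directly);
    # it is kept so existing calls keep working.
    model.train()
    loop = tqdm(enumerate(trainloader), total=len(trainloader))
    loop.set_description(f'Epoch [{Epoch}/{EPOCHS}]')  # epoch label, set once per epoch
    running_loss = 0.0
    right = 0  # correctly classified samples so far
    seen = 0   # samples processed so far
    for step, (batch_x, batch_y) in loop:
        batch_x, batch_y = batch_x.cuda(), batch_y.cuda()
        output = model(batch_x)
        optimizer.zero_grad()
        loss = criterion(output, batch_y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        _, predicted = torch.max(output.detach(), 1)
        # accumulate the number of correctly classified samples (as a Python int)
        right += (predicted == batch_y).sum().item()
        seen += batch_y.size(0)
        # update the bar: mean loss per batch and running accuracy
        loop.set_postfix(loss=running_loss / (step + 1), acc=right / seen)
```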
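To check the progress-bar bookkeeping without a GPU or a model, here is a minimal sketch that uses only tqdm. The function `fake_epoch` and the data `TOY_BATCHES` are hypothetical: each batch is a precomputed (loss, predictions, labels) triple standing in for one forward/backward step. It follows the same pattern as `train()` above: `set_description` for the epoch label, a running mean of the loss, and accuracy as correct samples over samples seen so far.

```python
import io
from tqdm import tqdm

# Hypothetical toy data: each batch is (loss, predictions, labels), standing in
# for one training step. The last batch is smaller, as with a DataLoader whose
# dataset size is not a multiple of the batch size.
TOY_BATCHES = [
    (0.9, [0, 1, 1, 0], [0, 1, 0, 0]),  # 3 of 4 correct
    (0.7, [1, 1, 0, 0], [1, 1, 0, 1]),  # 3 of 4 correct
    (0.5, [2, 2], [2, 2]),              # 2 of 2 correct
]


def fake_epoch(batches, epoch, epochs, file=None):
    """Hypothetical stand-in for train(): only the progress-bar bookkeeping.

    Returns (mean loss per batch, accuracy over all samples seen).
    file=None lets tqdm write to stderr as usual.
    """
    loop = tqdm(enumerate(batches), total=len(batches), file=file)
    loop.set_description(f'Epoch [{epoch}/{epochs}]')
    running_loss, right, seen = 0.0, 0, 0
    for step, (loss, preds, labels) in loop:
        running_loss += loss
        right += sum(p == y for p, y in zip(preds, labels))
        seen += len(labels)
        loop.set_postfix(loss=running_loss / (step + 1), acc=right / seen)
    # tqdm closes the bar itself when the loop is exhausted
    return running_loss / len(batches), right / seen


if __name__ == "__main__":
    buf = io.StringIO()  # capture the bar text instead of writing to stderr
    mean_loss, acc = fake_epoch(TOY_BATCHES, epoch=1, epochs=5, file=buf)
    print(buf.getvalue().strip().split("\r")[-1])  # final state of the bar
    print(f"mean loss {mean_loss:.3f}, accuracy {acc:.2f}")
```

Because `seen` counts samples instead of assuming a fixed batch size, the accuracy denominator stays correct however the batch sizes vary; with the three toy batches the bar ends at 8 correct out of 10 samples.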
Reference: https://zhuanlan.zhihu.com/p/378474516