话不多说，直接上代码。下面的代码在每个 epoch 结束时记录训练损失和验证准确率，训练结束后把两条曲线画在同一张图中并保存。
# Per-epoch history: one value is appended to each list at the end of every epoch.
train_losses = []
val_accuracies = []

# These two appends belong inside the epoch loop, after validation has
# produced `val_accurate` for the current epoch.
train_losses.append(running_loss / train_steps)  # mean training loss over the epoch's batches
val_accuracies.append(val_accurate)

# After training: derive the x-axis from the recorded history rather than from
# `epochs`, so the plot still works if fewer epochs were recorded (plt.plot
# raises ValueError when x and y have different lengths).
epoch_axis = np.arange(1, len(train_losses) + 1)
plt.plot(epoch_axis, train_losses, label="train_loss")
plt.plot(epoch_axis, val_accuracies, label="val_accuracy")
plt.xlabel("Epochs")
plt.title("Train_losses and val_accuracies")
plt.legend()
# No extension is given, so matplotlib appends rcParams["savefig.format"]
# (PNG unless configured otherwise).
# Save BEFORE plt.show(): once the window is closed the figure is released and
# a later savefig would write a blank image.
plt.savefig('./result')
plt.show()
具体如何应用，结合一个完整的训练代码示例会更清楚：下面被注释掉的部分是原有的训练循环，未注释的几行就是新增的代码，其中两行 append 应放在每个 epoch 的循环体内、验证结束之后。
train_losses = []    # per-epoch mean training loss
val_accuracies = []  # per-epoch validation accuracy
# Reference training loop (commented out, indentation restored for reading).
# epochs = 10
# best_acc = 0.0
# save_path = './{}Net.pth'.format(model_name)
# train_steps = len(train_loader)
# for epoch in range(epochs):
#     # train
#     net.train()
#     running_loss = 0.0
#     train_bar = tqdm(train_loader, file=sys.stdout)
#     for step, data in enumerate(train_bar):
#         images, labels = data
#         optimizer.zero_grad()
#         outputs = net(images.to(device))
#         loss = loss_function(outputs, labels.to(device))
#         loss.backward()
#         optimizer.step()
#
#         # print statistics
#         running_loss += loss.item()
#
#         train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
#                                                                  epochs,
#                                                                  loss)
#
#     # validate
#     net.eval()
#     acc = 0.0  # accumulate accurate number / epoch
#     with torch.no_grad():
#         val_bar = tqdm(val_loader, file=sys.stdout)
#         for val_data in val_bar:
#             val_images, val_labels = val_data
#             outputs = net(val_images.to(device))
#             predict_y = torch.max(outputs, dim=1)[1]
#             acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
#
#     val_accurate = acc / val_num
#     print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
#           (epoch + 1, running_loss / train_steps, val_accurate))
#
#     # NOTE: in a real script the two live lines below go HERE, indented one
#     # level inside `for epoch in range(epochs):`, so they run once per epoch.
val_accuracies.append(val_accurate)  # record this epoch's validation accuracy
train_losses.append(running_loss / train_steps)  # record this epoch's mean training loss
#
#     if val_accurate > best_acc:
#         best_acc = val_accurate
#         torch.save(net.state_dict(), save_path)

# After training, plot loss and accuracy against the epoch number.
# The x-axis is derived from the recorded history rather than from `epochs`,
# so the plot still works if fewer epochs were recorded (plt.plot raises
# ValueError when x and y have different lengths).
epoch_axis = np.arange(1, len(train_losses) + 1)
plt.plot(epoch_axis, train_losses, label="train_loss")  # training-loss curve
plt.plot(epoch_axis, val_accuracies, label="val_accuracy")  # validation-accuracy curve
plt.xlabel("Epochs")  # x-axis label
plt.title("Train_losses and val_accuracies")  # figure title
plt.legend()  # show each curve's label
# Output file path. No extension is given, so matplotlib appends the default
# format from rcParams["savefig.format"] (PNG unless configured otherwise).
plt.savefig('./result')
# Display the figure. Always save BEFORE plt.show(); otherwise the saved image is blank.
plt.show()