整体思路为,将每个iter 和每次输出的loss进行保存,然后利用plt进行绘图。
1## 保存loss和iteration
loss = []
iteration=[]
iteration.append(i) #i是你的iter
loss.append(total_loss.item()) #total_loss.item()是你每一次inter输出的loss
2## plt绘图
#num_iter是你总的迭代次数
if i==num_iter-1:
plt.figure()
plt.plot(iteration, loss, label="loss")
plt.draw()
plt.show()
(后序:脑子不好的人只能记录每一个细节,如果还能帮到和我一样的小白,就超级好。嘻嘻?)