pytorch绘制loss和accuracy曲线

1.前言

pytorch虽然使用起来很方便,但在一点上并没有tensorflow方便,就是绘制模型训练时在训练集和验证集上的loss和accuracy曲线(共四条)。tensorflow模型训练时,每次epoch的模型,以及在训练集和验证集上的loss和acc都保存在一个对象中,当我们要绘制四条曲线时,直接从对象中取值即可。

2.Loss曲线

Loss_list = []  #存储每次epoch损失值
def draw_loss(Loss_list,epoch):
    # 我这里迭代了200次,所以x的取值范围为(0,200),然后再将每次相对应的准确率以及损失率附在x上
    plt.cla()
    x1 = range(1, epoch+1)
    print(x1)
    y1 = Loss_list
    print(y1)
    plt.title('Train loss vs. epoches', fontsize=20)
    plt.plot(x1, y1, '.-')
    plt.xlabel('epoches', fontsize=20)
    plt.ylabel('Train loss', fontsize=20)
    plt.grid()
    plt.savefig("./lossAndacc/Train_loss.png")
    plt.savefig("./lossAndacc/Train_loss.png")
    plt.show()

3.acc曲线

def draw_fig(list,name,epoch):
    # 我这里迭代了200次,所以x的取值范围为(0,200),然后再将每次相对应的准确率以及损失率附在x上
    x1 = range(1, epoch+1)
    print(x1)
    y1 = list
    if name=="loss":
        plt.cla()
        plt.title('Train loss vs. epoch', fontsize=20)
        plt.plot(x1, y1, '.-')
        plt.xlabel('epoch', fontsize=20)
        plt.ylabel('Train loss', fontsize=20)
        plt.grid()
        plt.savefig("./lossAndacc/Train_loss.png")
        plt.show()
    elif name =="acc":
        plt.cla()
        plt.title('Train accuracy vs. epoch', fontsize=20)
        plt.plot(x1, y1, '.-')
        plt.xlabel('epoch', fontsize=20)
        plt.ylabel('Train accuracy', fontsize=20)
        plt.grid()
        plt.savefig("./lossAndacc/Train _accuracy.png")
        plt.show()

这里我把绘制loss和acc曲线的代码进行了合并。
用法如下,测试模型在验证集上的loss和acc时,让结果返回两个list对象,分别存储了每次epoch时的loss和acc的值。然后调用draw_fig方法,把对象作为参数传递进去。

if __name__ == '__main__':
    # val(model)

    with torch.no_grad():
        criterion = nn.BCEWithLogitsLoss().cuda()
        epoch = 30
        loss=[]
        acc=[]
        for i in range(1, epoch + 1):
            dir = "./result/20201029_2110/checkpoints/" + str(i) + ".pth"
            model = torch.load(dir)
            model.eval()  # 需要加上model.eval(). 否则的话,有输入数据,即使不训练,它也会改变权值
            loss1,acc1=auto_val(model,criterion)
            loss.append(loss1)
            acc.append(acc1)
        draw_fig(loss,"loss",epoch)
        draw_fig(acc,"acc",epoch)
  • 14
    点赞
  • 164
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 20
    评论
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

做个好男人!

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值