pytorch matplotlib实时可视化训练过程

在《keras matplotlib实时可视化训练过程》这篇文章中,我们已经实现了keras训练神经网络的时候,实时观察曲线的变化,那么在pytorch框架下该如何实现呢?也很简单,直接上程序:

# -*- coding: utf-8 -*-
# TODO: LQD 2019/10/24
# TODO: qq:743701947
import torch
import matplotlib.pyplot as plt
import threading

_loss_save = []
_trainAcc_save = []
_testAcc_save = []
flag_plot = True


def _thread_plot_all():
    global _loss_save, _trainAcc_save, _testAcc_save, flag_plot
    fig = plt.figure('acc---------loss')
    ax1 = fig.add_subplot(1, 2, 1)
    ax2 = fig.add_subplot(1, 2, 2)
    ax1.set_title('acc')
    ax1.set_xlabel('epoch')
    ax1.set_ylabel('acc')
    ax2.set_title('loss')
    ax2.set_xlabel('epoch')
    ax2.set_ylabel('loss')
    plt.ion()
    for i in range(100000):
        if flag_plot == True:
            try:
                ax1.lines.remove(lines1[0])
                ax1.lines.remove(lines2[0])
                ax2.lines.remove(lines3[0])
            except Exception as e:
                pass
            lines1 = ax1.plot(_trainAcc_save, 'r-', label='acc')
            lines2 = ax1.plot(_testAcc_save, 'b-', label='val_acc')
            ax1.legend()
            lines3 = ax2.plot(_loss_save, 'r-', label='loss')
            ax2.legend()
            plt.pause(0.5)
    plt.ioff()
    # plt.show()


class MyThread(threading.Thread):
    def __init__(self, threadID, name, func):
        threading.Thread.__init__(self)
        self.threadID = threadID
        self.name = name
        self.func = func

    def run(self):
        print('thread {} is started!'.format(self.threadID))
        self.func()


def visualization_of_deep_learning_training():
    t1 = MyThread(1, 'Thread-1', _thread_plot_all)
    t1.start()
    plt.show()

上面的程序放到总程序的最开始即可,在pytoch的fit里面,要稍作修改(需要增加的地方用need标记了):

def t_main(is_plot_result=False):
    global _loss_save, _trainAcc_save, _testAcc_save, flag_plot # need
    model = CnnNet()
    if torch.cuda.is_available():
        model = model.cuda()

    Loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adamax(model.parameters(), lr=learning_rate)
    nn.utils.clip_grad_norm_(filter(lambda p: p.requires_grad, model.parameters()), max_norm=5)

    _loss_save = []
    _trainAcc_save = []
    _testAcc_save = []
    for epoch in range(epoches):
        print('epoch {}'.format(epoch + 1))
        train_loss = 0.
        train_acc = 0.
        for i, data in enumerate(dloader_train):
            img, label = data
            label = label.squeeze(1)
            label = label.long()

            img, label = img.to(device), label.to(device)
            # print('step= {0}, img= {1}'.format(i+1, img.size()))

            optimizer.zero_grad()
            out = model(img)
            loss = Loss(out, label)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            pre = torch.max(out, 1)[1]
            train_correct = (pre.cpu() == label.cpu()).sum()
            train_acc += train_correct.item()
            if i == len(dloader_train) - 1:
                _loss_save.append(train_loss / len(dloader_train)) # need
                _trainAcc_save.append((train_acc / len(dloader_train)) / 100) # need
                print('train loss= ', train_loss / len(dloader_train))
                print('train acc= ', train_acc / len(dloader_train))
                train_loss = 0.
                train_acc = 0.

        with torch.no_grad():
            correct = 0
            total = 0
            for i, test_data in enumerate(dloader_test):
                test_img, test_label = test_data
                test_label = test_label.squeeze(1)
                test_label = test_label.long()

                images, labels = test_img.to(device), test_label.to(device)
                outputs = model(images)

                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        _testAcc_save.append((100 * correct / total) / 100) # need
        # torch.save(model.state_dict(), '_model.pkl')
        print('第%d个epoch的识别准确率为:%f' % (epoch + 1, (100 * correct / total)))
    flag_plot = False # need

    if is_plot_result == True:
        plt.figure('train loss')
        plt.plot(_loss_save)
        plt.xlabel('Epochs')
        plt.ylabel('Train Loss')
        plt.savefig('TrainLoss.png')
        plt.figure('train acc')
        plt.plot(_trainAcc_save)
        plt.xlabel('Epochs')
        plt.ylabel('Train Acc')
        plt.savefig('TrainAcc.png')
        plt.figure('test acc')
        plt.plot(_testAcc_save)
        plt.xlabel('Epochs')
        plt.ylabel('Test Acc')
        plt.savefig('TestAcc.png')

在最后,我们可以直接在if main 里面调用即可:

if __name__ == '__main__':
    visualization_of_deep_learning_training()
    t_main(is_plot_result=False)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值