深度学习Pytorch借助ProgBar实现进度条打印训练过程信息

tensorflow中的keras提供了Progbar进度条工具,通过它就可以实现和tensorflow一样的打印训练过程的进度条。
代码如下:

import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import tensorflow as tf
创建模型
model = Model()
model.to(device)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
打印训练一次的函数
def train_one_epoch(epoch, epochs, model, train_loader, test_loader):
    num_training_samples = len(train_loader.dataset)
    batch_size = train_loader.batch_size
    metrics_names = ['loss' ,'acc', 'val_loss', 'val_acc']
    
    print("\nepoch {}/{}".format(epoch+1,epochs))
    progBar = tf.keras.utils.Progbar(num_training_samples//batch_size, stateful_metrics=metrics_names) 

    # 训练模式
    correct = 0
    total = 0
    sum_loss = 0
    train_step = 0
    
    model.train()
    for x, y in train_loader:
        train_step += 1
        x, y = x.to(device), y.to(device)
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        with torch.no_grad():
            y_ = torch.argmax(y_pred, dim=1)
            correct += (y_ == y).sum().item()
            total += y.size(0)
            sum_loss += loss.item()
            running_loss = sum_loss / total
            running_acc = correct / total
            
        values = [('loss',running_loss),('acc',running_acc)]
        progBar.update(train_step-1, values=values)
    
    epoch_loss = sum_loss / total
    epoch_acc = correct / total
    
    # 测试模式
    test_correct = 0
    test_total = 0
    test_sum_loss = 0
    test_step = 0
    
    model.eval()
    with torch.no_grad():
        for x, y in test_loader:
            test_step += 1
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_ = torch.argmax(y_pred, dim=1)
            test_correct += (y_ == y).sum().item()
            test_total += y.size(0)
            test_sum_loss += loss.item()
    
    test_epoch_loss = test_sum_loss / test_total
    test_epoch_acc = test_correct / test_total

    values = [('val_loss',test_epoch_loss),('val_acc',test_epoch_acc)]
    progBar.update(train_step, values=values, finalize=True) 
    return epoch_loss, epoch_acc, test_epoch_loss, test_epoch_acc

def fit(epochs=10):
    train_loss = []
    train_acc = []
    test_loss = []
    test_acc = []
    
    for epoch in range(epochs):
        epoch_loss, epoch_acc, test_epoch_loss, test_epoch_acc = train_one_epoch(epoch, epochs, model, train_dl, test_dl)
        train_loss.append(epoch_loss)
        train_acc.append(epoch_acc) 
        test_loss.append(test_epoch_loss)
        test_acc.append(test_epoch_acc)

    history = {'loss':train_loss,'acc':train_acc,'val_loss':test_loss,'val_acc':test_acc}
    return history
训练多次
history = fit(epochs=10)
epoch 1/10
937/937 [==============================] - 8s 9ms/step - loss: 0.0080 - acc: 0.8148 - val_loss: 0.0015 - val_acc: 0.8552

epoch 2/10
937/937 [==============================] - 8s 9ms/step - loss: 0.0051 - acc: 0.8806 - val_loss: 0.0013 - val_acc: 0.8810

epoch 3/10
937/937 [==============================] - 8s 9ms/step - loss: 0.0043 - acc: 0.8989 - val_loss: 0.0012 - val_acc: 0.8948

epoch 4/10
937/937 [==============================] - 8s 9ms/step - loss: 0.0038 - acc: 0.9105 - val_loss: 0.0011 - val_acc: 0.9021

epoch 5/10
937/937 [==============================] - 8s 9ms/step - loss: 0.0034 - acc: 0.9202 - val_loss: 0.0011 - val_acc: 0.8988

epoch 6/10
937/937 [==============================] - 8s 9ms/step - loss: 0.0030 - acc: 0.9269 - val_loss: 0.0011 - val_acc: 0.9039

epoch 7/10
937/937 [==============================] - 8s 9ms/step - loss: 0.0027 - acc: 0.9356 - val_loss: 9.5237e-04 - val_acc: 0.9148

epoch 8/10
937/937 [==============================] - 8s 9ms/step - loss: 0.0024 - acc: 0.9432 - val_loss: 0.0012 - val_acc: 0.8895

epoch 9/10
937/937 [==============================] - 8s 9ms/step - loss: 0.0022 - acc: 0.9480 - val_loss: 9.7795e-04 - val_acc: 0.9140

epoch 10/10
937/937 [==============================] - 8s 9ms/step - loss: 0.0019 - acc: 0.9543 - val_loss: 0.0010 - val_acc: 0.9156
画图
pd.DataFrame(history).plot()

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值