211024-Pytorch模型训练中显示进度条

# 🌟 1. 定义进度条
# 🌟 2. 设置迭代器
# 🌟 3. 设置开头
# 🌟 4. 设置结尾

# ⚠️:tqdm会在一定程度上降低训练的速度
# ⚠️:如果需要观察变化,需要使用time.sleep

from tqdm import tqdm
import time
    # ! with tqdm
    def fit(self, args, model, device, train_loader, optimizer, epoch, vis=True):
        # print('Epochs: ', epoch)
        correct = 0
        L = len(train_loader.dataset)
        model.train()
        with tqdm(train_loader, unit="batch") as tepoch: # 🌟 1. 定义进度条
            for data, target in tepoch:              # 🌟 2. 设置迭代器
                tepoch.set_description(f"Epoch {epoch}") # 🌟 3. 设置开头
                data, target = data.to(device), target.to(device)#.squeeze(1)       # Data to device
                optimizer.zero_grad()                                              # Zero gradient 
                output = model(data)                                               # Forward propagation
                losstr = F.nll_loss(output,target)                                 # Calculate loss
                losstr.backward()                                                  # Back propagation
                optimizer.step()                                                   # Optimize parameters
        
                predict = output.argmax(dim=1, keepdim=True)                           # Get the index 
                correct = predict.eq(target.view_as(predict)).sum().item() 
            
                model.global_step += 1
                # model.vis_moni_train(losstr.item(),correct,model.global_step) if vis==True else None # Visulize
                accuracy = correct/len(data)
                tepoch.set_postfix(loss=losstr.item(), accuracy='{:.3f}'.format(accuracy)) # 🌟 4. 设置结尾
                sleep(0.0001)

在这里插入图片描述

    # ! withour tqdm
    def fit(self, args, model, device, train_loader, optimizer, epoch, vis=True):
        print('Epochs: ', epoch)
        correct = 0
        L = len(train_loader.dataset)
        model.train()
        for data, target in tqdm(train_loader):              # Load from cpu
            data, target = data.to(device), target.to(device)#.squeeze(1)       # Data to device
            optimizer.zero_grad()                                              # Zero gradient 
            output = model(data)                                               # Forward propagation
            losstr = F.nll_loss(output,target)                                 # Calculate loss
            losstr.backward()                                                  # Back propagation
            optimizer.step()                                                   # Optimize parameters
    
            predict = output.argmax(dim=1, keepdim=True)                           # Get the index 
            correct = predict.eq(target.view_as(predict)).sum().item() 
        
            model.global_step += 1
            model.vis_moni_train(losstr.item(),correct,model.global_step) if vis==True else None # Visulize
            accuracy = correct/len(data)
  • https://towardsdatascience.com/training-models-with-a-progress-a-bar-2b664de3e13e
  • https://blog.csdn.net/dreaming_coder/article/details/113486645
  • 8
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

GuokLiu

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值