Pytorch与深度学习自查手册4-训练、可视化、日志输出、保存模型

Pytorch与深度学习自查手册4-训练、可视化、日志输出、保存模型

训练和验证(包含可视化、日志、保存模型)

初始化模型、dataloader都完善以后,正式进入训练部分。

训练部分包括:

  1. 及时的日志记录

  2. tensorboard可视化log

  3. 输入

  4. 前向传播

  5. loss计算

  6. 反向传播

  7. 权重更新

  8. 固定步骤进行验证

  9. 最佳模型的保存(+bad case输出)

日志记录

利用logging模块在控制台实时打印并及时记录运行日志。

from config import  *
import logging  # 引入logging模块
import os.path
class Logger:
    def __init__(self,mode='w'):
        # 第一步,创建一个logger
        self.logger = logging.getLogger()
        self.logger.setLevel(logging.INFO)  # Log等级总开关
        # 第二步,创建一个handler,用于写入日志文件
        rq = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
        log_path = os.getcwd() + '/Logs/'
        log_name = log_path + rq + '.log'
        logfile = log_name
        fh = logging.FileHandler(logfile, mode=mode)
        fh.setLevel(logging.DEBUG)  # 输出到file的log等级的开关
        # 第三步,定义handler的输出格式
        formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
        fh.setFormatter(formatter)
        # 第四步,将logger添加到handler里面
        self.logger.addHandler(fh)
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)  # 输出到console的log等级的开关
        ch.setFormatter(formatter)
        self.logger.addHandler(ch)

完整的训练流程

import os
import math
import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
import torch.optim.lr_scheduler as lr_scheduler
import sys
from tqdm import tqdm
import torch

def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    mean_acc = torch.zeros(1).to(device)
    optimizer.zero_grad()

    data_loader = tqdm(data_loader)
    for iteration, data in enumerate(data_loader):
        batch, labels = data
        pred = model(batch.to(device))

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        mean_loss = (mean_loss * iteration + loss.detach()) / (step + 1)  # update mean losses
        pred = torch.max(pred, dim=1)[1]
        iter_acc=torch.eq(pred, labels.to(device)
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值