【Pytorch】训练模型

一、训练完整流程

使用Pytorch训练神经网络的一般流程为(伪代码,许多功能需要自己实现,这里只列出了流程):

import torch
import torch.optim as optim
import numpy as np
from tqdm import tqdm
from tensorboardX import SummaryWriter
from torch.optim.lr_scheduler import StepLR

def train():
    with torch.cuda.device(gpu_id):
        ## model
        model = Model() # 定义模型
        model = model.cuda()

        # optimizer & lr_scheduler
        optimizer = torch.optim.SGD(model.parameters(), lr=init_lr,
                                    momentum=momentum, weight_decay=weight_decay)

        lr_scheduler = StepLR(optimizer, step_size=25, gamma=0.8) # 定义学习率
        # lr_scheduler = lr_decay()  # 也可以是自己定义的学习率下降方式,比如定义了一个列表

        if resume:  # restore from checkpoint
            model, optimizer = restore_from(model, optimizer, ckpt) # 恢复训练状态

        # load train data
        trainloader, validloader = dataset() #自己定义DataLoader

        ### logs
        logger = create_logger()  # 自己定义创建的log日志
        summary_writer = SummaryWriter(log_dir) # tensorboard


        ### start train
        for epoch in range(end_epoch):
            scheduler.step() # 更新optimizer的学习率,一般以epoch为单位,即多少个epoch后换一次学习率

            train_loss = []
            model.train()
            model = model.cuda()

            ## train
            for i, data in enumerate(tqdm(trainloader)):
                input, target = data
                optimizer.zero_grad() #使用之前先清零
                output = model(input.cuda())
                loss = Loss(output, target)  # 自己定义损失函数

                loss.backward() # loss反传,计算模型中各tensor的梯度
                optimizer.step() #用在每一个mini batch中,只有用了optimizer.step(),模型才会更新
                train_loss.append(loss)
            train_loss = np.mean(train_loss) # 对各个mini batch的loss求平均

            ## eval,不需要梯度反传
            valid_loss = []
            
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值