最详细的语义分割---04PSPNet的训练

**

实例化模型

**
由于我们前面已经把相应的模块都已经准备好了,我们在这一部分只需要把他们导入过来,并对相应的超参数进行赋值即可。这里device是对设备的类型进行判断,若存在GPU,我们则使用GPU加速训练。
在这里插入图片描述
实例化我们之前写好的dataloder和网络。拿到网络的优化器,和学习率调整方式、损失函数。这里需要注意的是,voc数据集的背景类,即标签中黑色的部分,它的标签是255,所以我们在这里计算预测结果和真实标签的时候忽略背景。关于如何忽略背景,只要在交叉熵函数中的ignore_idex=255即可。
在这里插入图片描述

train_epoch

定义训练一个epoch的函数,这个函数传入的参数是我们当前训练到那个epoch了,特别要注意的一点是,在进行训练的时候,第一步一定要把优化器的梯度清零,不然可能会造成梯度爆炸。每行代码的意思,我都进行了注释,这里就不一一解释了。
在这里插入图片描述

这里会调用eval_metrics方法计算当前训练的性能指标,它的输入是网络预测输出,标签和分类类别总数。
具体的实现方式可以参见我的博客:https://blog.csdn.net/weixin_47142735/article/details/115792241?spm=1001.2014.3001.5501。
这个函数会返回一个列表,其内容如下[分类正确的像素总数,该批次像素总数,[每个类别预测图于标签相交的像素总数],[每个类别预测图标签图相并的像素总数]]。其中前两个元素是两个数,后面的两个元素是一个列表,列表中的每个元素是每类像素相交的元素总数。这里我们会把每次计算出来的Iou和PA保存起来,方便显示当前模型的性能指标。

在这里插入图片描述

val_epoch

验证函数的作用就是监视网络训练,避免网络训练过拟合。它的编写思路整体于训练脚本类似,只不过整个过程不需要反传梯度。我们需要注意的是,验证时一定要把模型设置为验证模式。
在这里插入图片描述

权值保存

我们训练网络的目的就是为了找到最合适的权值,但是我们没有必要把每个权值都保存下来,我么只需要在验证的时候按条件保存权值即可。很多时候,训练网络都是一个漫长的过程,总会出现无法一次训练结束的情况,所以这个时候也许要保存权值,方便下次的继续训练。
这一小段代码就可以实现上面的功能,当我们保存权值的时候以字典的形式记录epoch,权值字典和优化器参数,这样在下次训练的时候就可以接着之前的epoch和学习率接着训练了。
在这里插入图片描述

整个训练

前面的tian和val都是只训练一个epoch,即只轮流训练了一个数据集中所有图片一次,这是远远不够的,所以,我们还要在这个基础上进行大的循环,多训练记次网络,同时按照条件进行验证。
在这里插入图片描述

整个代码

import torch
from tqdm import tqdm # 进度条显示
from torch.utils.data import DataLoader
from dataloader import VOCDataset
from model import PSPNet
from  helper import eval_metrics
import numpy as np
import os

# 判断能否使用gpu加速运算
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
voc_root = r'D:\1Apython\Pycharm_pojie\data_set\VOCdevkit'  # 填写voc数据集的路径
save_dir = r'./weights'  # 权重存储位置
EPOCH = 200  # 总的训练次数
num_classes = 21   # voc数据集的类别总数
batch_size = 4   # 数据集的btch_size大小
pre_val = 2    # 多少次训练验证和保存权重一次
crop_size = 284  # 裁剪大小

# 实例化 daloader
train_datasets = VOCDataset(root=voc_root,split='train',num_classes=num_classes,base_size=300,crop_size=crop_size)
val_datasets = VOCDataset(root=voc_root,split='val',num_classes=num_classes,base_size=300,crop_size=crop_size)
train_dataloader = DataLoader(train_datasets,batch_size=batch_size,num_workers=1,shuffle=True,drop_last=True)
val_dataloader = DataLoader(train_datasets,batch_size=batch_size,num_workers=1,shuffle=True,drop_last=True)

model = PSPNet(num_classes=num_classes,pretrained=True)  #实例化PSPNet
#实例化优化器
optimizer = torch.optim.SGD(lr=0.005,params=model.parameters())
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=1,gamma=0.9)
# 实例化学习率更新策略,可以根据自己的需求选择不同的调整方法,这里随便使用了一个StepLR
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=255)
# 实例化损失函数,voc数据集背景标签为255,所以我们计算交叉熵的时候忽略背景
def get_lr(optimizer): # 拿到变化的学习率
    for param_group in optimizer.param_groups:
        #print(param_group['lr'])
        return param_group['lr']
# 训练函数
def train_epoch(epoch):
    total_loss = 0  # 保存当前epoch的损失
    total_inter, total_union = 0, 0 # 批次图像的交集、并集
    total_correct, total_label = 0, 0  # 批次图像所有预测正确的像素点、批次图像所有的像素点
    model.to(device)
    model.train()  # 将网络设置为训练模式
    tbar = tqdm(train_dataloader, ncols=130)   # 封装显示模块
    for index,(image,label) in enumerate(tbar):
        image = image.to(device)  # 搬运到GPU上进行训练
        label = label.to(device)
        output = model(image)   # 拿到模型的预测结果.

        assert output[0].size()[2:] == label.size()[1:]  #检查结果
        assert output[0].size()[1] == num_classes
        loss = loss_fn(output[0], label)  # 主干网络损失
        loss += loss_fn(output[1], label) * 0.4  # 辅助网络损失
        output = output[0]   #记录主干网络的预测结果,后面计算性能指标使用
        loss.backward()    # 反传梯度
        optimizer.step()   # 梯度更新
        optimizer.zero_grad()  # 优化器梯度清零
        lr_scheduler.step(epoch=epoch - 1) # 学习率更新

        lr = get_lr(optimizer)  # 拿到当前学习率
        total_loss += loss.item()   # 保存损失

        seg_metrics = eval_metrics(output, label, num_classes)  # 计算每批次PA和miou
        #返回一个列表[计算正确的像素总数,像素总数,标签与预测图相交部分,标签与预测图相并部分(每个类别)]

        correct, num_labeled,inter, union = seg_metrics  # 对seg_metircs进行解包
        "将该epoch中所有正确的像素总数、所有像素总数、交集、和并集累加起来"
        total_correct += correct  # 更新批次图像计算正确的像素
        total_label += num_labeled  # 更新总的像素值
        total_inter += inter     # 更新相交区域的值
        total_union += union     # 更新相并部分的值

        # 计算平均值
        "这里计算的PA和mIoU是将一个epoch中每个batch的交并比进行累加,然后计算平均交并比"
        pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)  #计算PA=正确分类像素总数/像素总数
        IoU = 1.0 * total_inter / (np.spacing(1) + total_union) # 计算Iou = 相交部分/相并部分  np.spacing(1)防止分母为0的情况
        mIoU = IoU.mean() # 计算类别的平均IoU

        # 显示打印信息
        tbar.set_description(
            'TRAIN {}/{} | Loss: {:.3f}| Acc {:.2f} mIoU {:.2f}  | lr {:.8f}|'.format(
                epoch,EPOCH, np.round(total_loss/(index+1),3),np.round(pixAcc,3),
                np.round(mIoU,3),lr))
    lr_scheduler.step()  # 学习率更新

# 验证函数
def val_epoch(epoch):
    total_loss = 0   # 保存验证的总损失
    total_inter, total_union = 0, 0
    total_correct, total_label = 0, 0

    model.to(device)
    model.eval()                   # 开启验证模式
    print(f'正在使用 {device} 进行验证! ')
    tbar = tqdm(val_dataloader,ncols=130)  # 设置进度条信息
    with torch.no_grad():   # 关闭梯度信息
        for index,(image,label) in enumerate(tbar):
            image = image.to(device)       # 搬运到GPU上进行预测
            label = label.to(device)
            output = model(image)          # 传入模型获得预测结果
            loss = loss_fn(output,label)   # 计算验证的时候的损失
            total_loss  += loss.item()      # 累计loss

            seg_metrics = eval_metrics(output, label, num_classes)  # 计算每批次PA和miou
            correct, num_labeled, inter, union = seg_metrics  # 对seg_metircs进行解包
            "将该epoch中所有正确的像素总数、所有像素总数、交集、和并集累加起来"
            total_correct += correct  # 更新批次图像计算正确的像素
            total_label += num_labeled  # 更新总的像素值
            total_inter += inter  # 更新相交区域的值
            total_union += union  # 更新相并部分的值

            # 计算平均值
            "这里计算的PA和mIoU是将一个epoch中每个batch的交并比进行累加,然后计算平均交并比"
            pixAcc = 1.0 * total_correct / (np.spacing(1) + total_label)  # 计算PA=正确分类像素总数/像素总数
            IoU = 1.0 * total_inter / (np.spacing(1) + total_union)  # 计算Iou = 相交部分/相并部分  np.spacing(1)防止分母为0的情况
            mIoU = IoU.mean()  # 计算类别的平均IoU
            # 显示当前的预测信息
            tbar.set_description('EVAL ({})|Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f}|'.format(epoch,
                                                        total_loss/(index+1),(pixAcc), mIoU))
        print('Finish validation!') # 显示所有验证图片的平均信息
        print(f'total loss:{np.round(total_loss/(index+1),3)} || PA:{np.round(pixAcc,3)} || mIoU:{np.round(mIoU,3)}')
        print(f'every class Iou {dict(zip(range(num_classes), np.round(IoU,3)))}')

        print('正在保存权重!!!!')
        state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        filename = os.path.join(save_dir, f'checkpoint--epoch{epoch}.pth')
        torch.save(state, filename)
        print(f'成功保存第{epoch}epoch权重文件')

# 总训练函数
def train(EPOCH):
    print(f'正在使用 {device} 进行训练! ')
    for i in range(EPOCH):
        train_epoch(i)   # 调用上面的train_epoch 进行一轮训练
        if i % pre_val == 0:  # 按照条件进行验证
            val_epoch(i)

if __name__ == '__main__':
    train(EPOCH)

如果需要整个文件的代码,可以到下面这个网盘链接下载
链接:https://pan.baidu.com/s/19-CMyQvzIduxGeVwtVN7cQ
提取码:fbkd
在这里插入图片描述

  • 23
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
PSPNet是一种用于语义分割深度学习网络模型,它通过使用金字塔池化模块来整合基于不同区域的上下文信息,从而提供了有效的全局上下文先验。相比于其他最先进的方法,PSPNet在效果上表现更好。\[1\] 金字塔池化模块可以收集具有层级的信息,比全局池化更有代表性。同时,PSPNet的计算量并没有比原来的空洞卷积FCN网络有很大的增加。在端到端学习中,全局金字塔池化模块和局部FCN特征可以被同时训练。\[2\] PSPNet在语义分割任务中具有优越的性能。它利用金字塔池化模块和金字塔场景解析网络来聚合不同区域的全局上下文信息,从而生成高质量的场景解析结果。该方法在不同的数据集上实现了最先进的性能,例如在PASCAL VOC 2012和Cityscapes数据集上的mIoU准确性分别为85.4%和80.2%。\[3\] #### 引用[.reference_title] - *1* [PSPNet | 语义分割及场景分析](https://blog.csdn.net/qq_42722197/article/details/125611648)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down28v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [语义分割-PSPNet](https://blog.csdn.net/weixin_43925119/article/details/109706219)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insert_down28v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值