Pspnet模型训练私有数据集

1.准备工作

论文:PSPNET

代码:github

环境准备:

PyTorch>=1.1.0, Python3, tensorboardX
注:pytorch为gpu版本

2.论文介绍

PSPNet(Pyramid Scene Parsing Network)是一种用于图像语义分割的深度卷积神经网络模型。它于2016年由中国科学院自动化研究所提出,旨在解决图像语义分割中的像素级别分类问题。

PSPNet的核心思想是利用金字塔池化模块来捕捉不同尺度上的上下文信息,以提高对图像语义的理解和分割准确性。该模型的主要特点包括:

1.金字塔池化模块(Pyramid Pooling Module):该模块通过在不同尺度上进行池化操作,从不同层次上捕捉图像的全局和局部上下文信息。它能够有效地扩展感受野,使网络能够对不同尺度的对象和场景进行细粒度的分割。

2.ResNet作为主干网络:PSPNet通常使用ResNet作为主干网络,以提取图像特征。ResNet是一种深度残差网络,通过残差连接(residual connection)解决了深层网络训练中的梯度消失问题,有助于提高网络的收敛性和性能。

3.融合和上采样:在池化模块之后,PSPNet通过级联融合和上采样操作,将来自不同尺度的特征图进行融合,并将其上采样到原始图像的尺寸。融合后的特征图可以更准确地表示图像中的不同语义区域。

PSPNet在图像语义分割任务中取得了良好的性能,并在多个公开数据集上进行了验证和比较。它被广泛应用于许多计算机视觉任务,如场景理解、遥感图像分析和自动驾驶等领域。

PSPNet基于FCN(Fully Convolutional Network)ps:图像语义分割的深度卷积神经网络模型上进行改进的。

论文中还有一个细节是辅助损失(auxiliary loss),在resnet101的res4b22层引出一条FCN分支,用于计算辅助损失。论文里设置了赋值损失loss2的权重为0.4。

 该模型在VOC数据集上获得优秀表现:

3.模型训练 

1.git clone https://github.com/hszhao/semseg.git
2.cd semseg
3.sh tool/train.sh voc2012 pspnet50  

3.1数据集准备:VOC2012格式

特别说明:ImageSets/Segmentation目录下的txt文件内容为:

xxx.png    xxx.png
xxx.png    xxx.png
.
.
.
.

txt生成代码:

import os
import random

# 数据集路径
dataset_path = "/home/suxd/workspace/semseg/dataset/VOC2012/JPEGImages"

# 标签集路径
label_path = "/home/suxd/workspace/semseg/dataset/VOC2012/SegmentationClassAug"

# 输出文件路径
output_path = "/home/suxd/workspace/semseg/dataset/VOC2012/ImageSets/Segmentation"

# 确保输出文件夹存在
if not os.path.exists(output_path):
    os.makedirs(output_path)

# 获取所有PNG图片文件名
dataset = [file for file in os.listdir(dataset_path) if file.endswith('.png')]

# 确定划分比例
train_ratio = 0.7
val_ratio = 0.2
test_ratio = 0.1

# 打乱数据集
random.shuffle(dataset)

# 计算每个集合的大小
train_size = int(len(dataset) * train_ratio)
val_size = int(len(dataset) * val_ratio)
test_size = len(dataset) - train_size - val_size

# 划分数据集
train_set = dataset[:train_size]
val_set = dataset[train_size:train_size + val_size]
test_set = dataset[train_size + val_size:]

# 训练验证集包含训练集和验证集
train_val_set = train_set + val_set

# 生成txt文件
def save_to_txt(data, filename, dataset_dir, label_dir):
    with open(os.path.join(output_path, filename), 'w') as file:
        for item in data:
            # 获取图像文件的上一级目录路径
            image_rel_path = os.path.join(dataset_dir, item)
            # 获取对应的标签文件名
            label_filename = item.replace('JPEGImages', 'SegmentationClassAug').replace('.jpg', '.png')
            # 获取标签文件的上一级目录路径
            label_rel_path = os.path.join(label_dir, label_filename)
            # 写入文件名和扩展名
            file.write("%s %s\n" % (image_rel_path, label_rel_path))

# 保存训练集、验证集、测试集和训练验证集
save_to_txt(train_set, "train.txt", "JPEGImages", "SegmentationClassAug")
save_to_txt(val_set, "val.txt", "JPEGImages", "SegmentationClassAug")
save_to_txt(test_set, "test.txt", "JPEGImages", "SegmentationClassAug")
save_to_txt(train_val_set, "train_val.txt", "JPEGImages", "SegmentationClassAug")

3.2配置文件修改

.../semseg/config/voc2012/voc2012_pspnet50.yaml

 文件中路径全部修改为绝对路径;类别数根据自己数据集修改;train_h和train_w要求其数值减1,能被8整除;257-1=256/8=32;

evaluate:True  #训练过程进行验证

DATA:
  data_root: /home/suxd/workspace/semseg/dataset/VOC2012
  train_list: /home/suxd/workspace/semseg/dataset/VOC2012/ImageSets/Segmentation/train.txt
  val_list: /home/suxd/workspace/semseg/dataset/VOC2012/ImageSets/Segmentation/val.txt
  classes: 2

TRAIN:
  arch: psp
  layers: 50
  sync_bn: False  # adopt sync_bn or not
  train_h: 257
  train_w: 257
  scale_min: 0.5  # minimum random scale
  scale_max: 2.0  # maximum random scale
  rotate_min: -10  # minimum random rotate
  rotate_max: 10  # maximum random rotate
  zoom_factor: 8  # zoom factor for final prediction during training, be in [1, 2, 4, 8]
  ignore_label: 255
  aux_weight: 0.4
  train_gpu: [0]
  workers: 16  # data loader workers
  batch_size: 16  # batch size for training
  batch_size_val: 8  # batc
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Git_SUXD

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值