1.准备工作
论文:PSPNET
代码:github
环境准备:
PyTorch>=1.1.0, Python3, tensorboardX
注:pytorch为gpu版本
2.论文介绍
PSPNet(Pyramid Scene Parsing Network)是一种用于图像语义分割的深度卷积神经网络模型。它于2016年由中国科学院自动化研究所提出,旨在解决图像语义分割中的像素级别分类问题。
PSPNet的核心思想是利用金字塔池化模块来捕捉不同尺度上的上下文信息,以提高对图像语义的理解和分割准确性。该模型的主要特点包括:
1.金字塔池化模块(Pyramid Pooling Module):该模块通过在不同尺度上进行池化操作,从不同层次上捕捉图像的全局和局部上下文信息。它能够有效地扩展感受野,使网络能够对不同尺度的对象和场景进行细粒度的分割。
2.ResNet作为主干网络:PSPNet通常使用ResNet作为主干网络,以提取图像特征。ResNet是一种深度残差网络,通过残差连接(residual connection)解决了深层网络训练中的梯度消失问题,有助于提高网络的收敛性和性能。
3.融合和上采样:在池化模块之后,PSPNet通过级联融合和上采样操作,将来自不同尺度的特征图进行融合,并将其上采样到原始图像的尺寸。融合后的特征图可以更准确地表示图像中的不同语义区域。
PSPNet在图像语义分割任务中取得了良好的性能,并在多个公开数据集上进行了验证和比较。它被广泛应用于许多计算机视觉任务,如场景理解、遥感图像分析和自动驾驶等领域。
PSPNet基于FCN(Fully Convolutional Network)ps:图像语义分割的深度卷积神经网络模型上进行改进的。
论文中还有一个细节是辅助损失(auxiliary loss),在resnet101的res4b22层引出一条FCN分支,用于计算辅助损失。论文里设置了赋值损失loss2的权重为0.4。
该模型在VOC数据集上获得优秀表现:
3.模型训练
1.git clone https://github.com/hszhao/semseg.git
2.cd semseg
3.sh tool/train.sh voc2012 pspnet50
3.1数据集准备:VOC2012格式
特别说明:ImageSets/Segmentation目录下的txt文件内容为:
xxx.png xxx.png
xxx.png xxx.png
.
.
.
.
txt生成代码:
import os
import random
# 数据集路径
dataset_path = "/home/suxd/workspace/semseg/dataset/VOC2012/JPEGImages"
# 标签集路径
label_path = "/home/suxd/workspace/semseg/dataset/VOC2012/SegmentationClassAug"
# 输出文件路径
output_path = "/home/suxd/workspace/semseg/dataset/VOC2012/ImageSets/Segmentation"
# 确保输出文件夹存在
if not os.path.exists(output_path):
os.makedirs(output_path)
# 获取所有PNG图片文件名
dataset = [file for file in os.listdir(dataset_path) if file.endswith('.png')]
# 确定划分比例
train_ratio = 0.7
val_ratio = 0.2
test_ratio = 0.1
# 打乱数据集
random.shuffle(dataset)
# 计算每个集合的大小
train_size = int(len(dataset) * train_ratio)
val_size = int(len(dataset) * val_ratio)
test_size = len(dataset) - train_size - val_size
# 划分数据集
train_set = dataset[:train_size]
val_set = dataset[train_size:train_size + val_size]
test_set = dataset[train_size + val_size:]
# 训练验证集包含训练集和验证集
train_val_set = train_set + val_set
# 生成txt文件
def save_to_txt(data, filename, dataset_dir, label_dir):
with open(os.path.join(output_path, filename), 'w') as file:
for item in data:
# 获取图像文件的上一级目录路径
image_rel_path = os.path.join(dataset_dir, item)
# 获取对应的标签文件名
label_filename = item.replace('JPEGImages', 'SegmentationClassAug').replace('.jpg', '.png')
# 获取标签文件的上一级目录路径
label_rel_path = os.path.join(label_dir, label_filename)
# 写入文件名和扩展名
file.write("%s %s\n" % (image_rel_path, label_rel_path))
# 保存训练集、验证集、测试集和训练验证集
save_to_txt(train_set, "train.txt", "JPEGImages", "SegmentationClassAug")
save_to_txt(val_set, "val.txt", "JPEGImages", "SegmentationClassAug")
save_to_txt(test_set, "test.txt", "JPEGImages", "SegmentationClassAug")
save_to_txt(train_val_set, "train_val.txt", "JPEGImages", "SegmentationClassAug")
3.2配置文件修改
.../semseg/config/voc2012/voc2012_pspnet50.yaml
文件中路径全部修改为绝对路径;类别数根据自己数据集修改;train_h和train_w要求其数值减1,能被8整除;257-1=256/8=32;
evaluate:True #训练过程进行验证
DATA:
data_root: /home/suxd/workspace/semseg/dataset/VOC2012
train_list: /home/suxd/workspace/semseg/dataset/VOC2012/ImageSets/Segmentation/train.txt
val_list: /home/suxd/workspace/semseg/dataset/VOC2012/ImageSets/Segmentation/val.txt
classes: 2
TRAIN:
arch: psp
layers: 50
sync_bn: False # adopt sync_bn or not
train_h: 257
train_w: 257
scale_min: 0.5 # minimum random scale
scale_max: 2.0 # maximum random scale
rotate_min: -10 # minimum random rotate
rotate_max: 10 # maximum random rotate
zoom_factor: 8 # zoom factor for final prediction during training, be in [1, 2, 4, 8]
ignore_label: 255
aux_weight: 0.4
train_gpu: [0]
workers: 16 # data loader workers
batch_size: 16 # batch size for training
batch_size_val: 8 # batc