一、Segformer简介
在SegFormer
提出时,transformer
已经开始在图像领域展露头角。在此之前,SETR最早将transformer
结构引入到了语义分割任务中。SETR
采用ViT
作为backbone
,并结合多个CNN
decoder
来放大特征分辨率。但是SETR
具有两个局限性:
没有有效利用multi-scale
特征;
具有非常高的计算消耗。
为了解决上述问题,pyramid vision Transformer (PVT)
被提出。PVT
具有金字塔结构,使得分割结果有进一步的提升。但是包括PVT
、Swin
、Twins
等新兴方法都是在改进encoder
,但是忽略了decoder
的改进。
与以前的方法相比,SegFormer
同时考虑了效率、准确性和鲁棒性,作者重新设计了encoder
和decoder
,主要创新点包括:
一种新型的无位置编码(position-encoding-free)
分层变压器编码器;
一种轻量级的 All-MLP
(多层感知机) decoder
设计;
在三个公开数据集上达到了SOTA
的效果。
下面为Segformer
论文中,Segformer
分别与SETR
,DeepLabv3+
的效果对比图。
二、Segformer代码下载
代码可以在我的资源下载处获取。
三、CityScapes数据集简介
Cityscapes
数据集共有fine
和coarse
两套评测标准,前者提供5000张精细标注的图像,后者提供5000张精细标注外加20000张粗糙标注的图像。Cityscapes
数据集包含2975张图片。包含了街景图片和对应的标签。大小为113MB。Cityscapes
数据集,包含戴姆勒在内的三家德国单位联合提供,包含50多个城市的立体视觉数据。
四 CityScapes数据集下载
4.1 登录网站https://www.cityscapes-dataset.com/downloads/ 并使用邮箱注册账号
4.2 下载下图框出的两部分数据集:
五、 CityScapes数据标签格式转换
## 5.1 使用命令行
网站,下载zip压缩包方式,下载git clone https://githubfast.com/mcordts/cityscapesScripts.git
或者登录https://github.com/mcordts/cityscapesScriptsCityScapes
数据集解析代码。
下载的结果如下图:
5.2 下载和使用CityScapes工具
pip install cityscapesscripts -i https://pypi.tuna.tsinghua.edu.cn/simple
结果如下:
修改5.1.步骤中下载的代码cityscapesScripts/cityscapesscripts/preparation/createTrainIdLabelImgs.py
在 def main():
前添加一行os.environ[‘CITYSCAPES_DATASET’] = “你的CityScapes gtFine路径”
确保能找到你下载的数据集的标签路径。
运行cityscapesScripts/cityscapesscripts/preparation/createTrainIdLabelImgs.py
文件,进行标签数据格式转换:
在gt_Fine
的train val
文件夹下,增加了一些以_polygons.json" , "_labelTrainIds.png
结尾的文件,即为转换生成的数据、
对应createTrainIdLabelImgs.py
中的代码为:
现在,我们有了两个文件夹,一个是leftImg8bit
的原始图像文件夹,一个是gtFine
标注文件夹。
现在,我们要将这两个文件夹里面的图像都提取出来,存入train、val、test
文件夹中。
运行下面的代码,即可将原始图像提取并处理。
import os
import random
import shutil
# 数据集路径
dataset_path = r"dataset/cityscapes/leftImg8bit_trainvaltest/leftImg8bit"
#原始的train, valid文件夹路径
train_dataset_path = os.path.join(dataset_path,'train')
val_dataset_path = os.path.join(dataset_path,'val')
test_dataset_path = os.path.join(dataset_path,'test')
#创建train,valid的文件夹
train_images_path = os.path.join(dataset_path,'cityscapes_train')
val_images_path = os.path.join(dataset_path,'cityscapes_val')
test_images_path = os.path.join(dataset_path,'cityscapes_test')
if os.path.exists(train_images_path)==False:
os.mkdir(train_images_path )
if os.path.exists(val_images_path)==False:
os.mkdir(val_images_path)
if os.path.exists(test_images_path)==False:
os.mkdir(test_images_path)
#-----------------移动文件夹-------------------------------------------------
for file_name in os.listdir(train_dataset_path):
file_path = os.path.join(train_dataset_path,file_name)
for image in os.listdir(file_path):
shutil.copy(os.path.join(file_path,image), os.path.join(train_images_path,image))
for file_name in os.listdir(val_dataset_path):
file_path = os.path.join(val_dataset_path,file_name)
for image in os.listdir(file_path):
shutil.copy(os.path.join(file_path,image), os.path.join(val_images_path,image))
for file_name in os.listdir(test_dataset_path):
file_path = os.path.join(test_dataset_path,file_name)
for image in os.listdir(file_path):
shutil.copy(os.path.join(file_path,image), os.path.join(test_images_path,image))
运行后生成如下文件夹。
对于label文件也同样如此,比如下面生成19类的标注文件夹。
import os
import random
import shutil
# 数据集路径
dataset_path = r"dataset\cityscapes\gtFine_trainvaltest\gtFine"
#原始的train, valid文件夹路径
train_dataset_path = os.path.join(dataset_path,'train')
val_dataset_path = os.path.join(dataset_path,'val')
test_dataset_path = os.path.join(dataset_path,'test')
#创建train,valid的文件夹
train_images_path = os.path.join(dataset_path,'cityscapes_19classes_train')
val_images_path = os.path.join(dataset_path,'cityscapes_19classes_val')
test_images_path = os.path.join(dataset_path,'cityscapes_19classes_test')
if os.path.exists(train_images_path)==False:
os.mkdir(train_images_path )
if os.path.exists(val_images_path)==False:
os.mkdir(val_images_path)
if os.path.exists(test_images_path)==False:
os.mkdir(test_images_path)
#-----------------移动文件---对于19类语义分割, 主需要原始图像中的labelIds结尾图片-----------------------
for file_name in os.listdir(train_dataset_path):
file_path = os.path.join(train_dataset_path,file_name)
for image in os.listdir(file_path):
#查找对应的后缀名,然后保存到文件中
if image.split('.png')[0][-13:] == "labelTrainIds":
#print(image)
shutil.copy(os.path.join(file_path,image), os.path.join(train_images_path,image))
for file_name in os.listdir(val_dataset_path):
file_path = os.path.join(val_dataset_path,file_name)
for image in os.listdir(file_path):
if image.split('.png')[0][-13:] == "labelTrainIds":
shutil.copy(os.path.join(file_path,image), os.path.join(val_images_path,image))
for file_name in os.listdir(test_dataset_path):
file_path = os.path.join(test_dataset_path,file_name)
for image in os.listdir(file_path):
if image.split('.png')[0][-13:] == "labelTrainIds":
shutil.copy(os.path.join(file_path,image), os.path.join(test_images_path,image))
得到如下结果。
到这里,我们已经提取了所有的图像文件和标注文件。
六、训练
6.1 修改数据集路径
6.2 修改预训练权重路径
6.3 运行train.py文件
七、测试
运行predict.py
文件,我这里使用的是官方训练好的best1-model_city_b2_79_801.pth
模型权重,测试效果如下: