语义分割工具包MMSegmentation
https://github.com/open-mmlab/mmsegmentation
统一超参
模型配置文件 部分说明:
model = dict(
type='EncoderDecoder' OR 'CascadeEncoderDecoder', # 分割模型的主题架构
pretrained='open-mmlab://resnet50_v1c',
backbone=dict( # 主干网络
type='ResNetV1c',
# ... more options),
neck = None, # 颈部
decode_head=dict( # 主解码头
type='PSPHead',
# ... more options
loss_decode=dict( #损失函数
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)
),
auxiliary_head=dict( # 辅助解码头
type='FCNHead',
# ... more options
loss_decode=dict( #损失函数
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4
)
),
train_cfg=dict(...), # 训练和测试配置
test_cfg=dict(...)
)
主干网络的配置
backbone=dict(
type='ResNetV1c', #ResNet v1c 结构
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3), # 输出全部级别的特征图给颈部或者辅助解码头
dilations=(1, 1, 2, 4), # 配置降采样和空洞卷积,增大空洞卷积倍率
strides=(1, 2, 1, 1), # 同时移除采样率
norm_cfg= dict(type='SyncBN', requires_grad=True), # 分割模型会使用 SyncBN
norm_eval=False, # 增大batchSize
style='pytorch',
contract_dilation=True
)
主解码头
decode_head=dict(
type='PSPHead', #使用 PSPNet 的解码头
in_channels=2048,
in_index=3, # 以顶层特征图为输入
channels=512,
pool_scales=(1, 2, 3, 6), #池化金字塔的尺度
dropout_ratio=0.1,
num_classes=19, # 预测类别数
norm_cfg= dict(type='SyncBN', requires_grad=True),
align_corners=False,
loss_decode=dict( #逐像素交叉熵损失函数
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0
)
),
辅助解码头的配置
auxiliary_head=dict(
type='FCNHead', # 使用 FCN 解码头
in_channels=1024,
in_index=2, # 以低层次特征图为输入
channels=256, # FCN 中卷积通道数
num_convs=1, #FCN 中使用 1 层卷积
concat_input=False, #是否拼接输入特征图用于预测
dropout_ratio=0.1,
num_classes=19, # 预测类别数
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict( #逐像素交叉熵损失函数
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4
)
),
数据集配置
dataset_type = 'CityscapesDataset' #数据集类型
data_root = 'data/cityscapes/' #数据集路径
data = dict(
samples_per_gpu=2, #batch size DataLoader 参数
workers_per_gpu=2, #worker 个数 DataLoader 参数
train=dict( #训练集配置
type=dataset_type,
data_root=data_root, #训练集配置
img_dir='leftImg8bit/train',
ann_dir='gtFine/train', #标注文件目录
pipeline=train_pipeline #训练数据的处理流水线
),
val=dict(...), #验证集、训练集的配置
test=dict(...)
)