语义分割代码实战教学
HRNet 高分辨率神经网络
安装配置
# 选择分支
git branch -a
git switch 3.x
# 配置环境
conda create -n mmsegmentation python=3.8
conda activate mmsegmentation
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install mmcv==2.0.0rc3 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11/index.html
pip install -U openmim
mim install mmengine
pip install -v -e .
# 下载预训练模型
wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth
预训练语义分割模型预测图片
通过脚本文件,利用预训练模型进行预测
python demo/image_demo.py \
data/street_uk.jpeg \
configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \
https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \
--out-file outputs/B1_uk_pspnet.jpg \
--device cuda:0 \
--opacity 0.5
--opacity
的作用是调节透明度,更像原图或者更像语义分割后的图
通过编写api来提取信息
from mmseg.apis import init_model
model = init_model(config_file, checkpoint_file, device='cuda:0')
result = inference_model(model, img_path)
result
result.keys()
>>> ['pred_sem_seg', 'seg_logits']
# result.pred_sem_seg中语义分割图为单通道图,每个值为0-18,即共19各类别
result.pred_sem_seg.data.shape
>>> torch.Size([1, 1500, 2250])
# 一共多少类别
np.unique(result.pred_sem_seg.data.cpu())
>>> array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 13, 15])
# result.seg_logits是置信度,每一个像素属于预测类别的置信度
result.seg_logits.data.shape
>>> torch.Size([19, 1500, 2250])
预训练语义分割模型预测视频
python demo/video_demo.py \
data/traffic.mp4 \
configs/pspnet/pspnet_r50-d8_4xb2-40k_cityscapes-512x1024.py \
https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth \
--device cuda:0 \
--output-file outputs/B3_video.mp4 \
--opacity 0.5
在自己数据集上训练语义分割模型
下载数据集
wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/iccv09Data.tar.gz -O stanford_background.tar.gz
修改数据集类
from mmseg.registry import DATASETS
from mmseg.datasets import BaseSegDataset
@DATASETS.register_module()
class StanfordBackgroundDataset(BaseSegDataset):
METAINFO = dict(classes = classes, palette = palette)
def __init__(self, **kwargs):
super().__init__(img_suffix='.jpg', seg_map_suffix='.png', **kwargs)
修改config
配置文件
- 修改model.head.num_classes
- 修改数据集的data_type和data_root
- 指定训练集的路径和测试集的路径
- 指定预训练模型权重文件路径
- 修改训练配置参数,训练epoch,batch_size等