【OpenMMLab】MMSegmentation 代码

系列文章目录

第一课:【OpenMMLab】OpenMMLab概述
第二课:【OpenMMLab】人体姿态估计、关键点检测与MMPose
第三课:【openMMLab】MMPose 代码教程
第四课:【OpenMMLab】深度学习预训练与 MMPreTrain
第五课: 【OpenMMLab】MMPretrain 代码教程
第六课:【OpenMMLab】目标检测与MMDetection
第七课:【OpenMMLab】MMDetection 代码
第八课:【OpenMMLab】语义分割与MMSegmentation

MMSegmentation 语义分割算法库

MMSegmentation 是专门做图像分割,尤其是语义分割的算法库。在无人驾驶、遥感图像、医疗影响等领域,语义分割都是非常重要的算法。算法库内容非常丰富,包括 600+ 个预训练模型和 40+ 篇算法复现。
子豪兄视频教程地址:https://www.bilibili.com/video/BV1uh411T73q
代码教程地址:https://github.com/TommyZihao/MMSegmentation_Tutorials
算法库地址:https://github.com/open-mmlab/mmsegmentation

安装配置

安装pytorch

# torch 版本 1.10.1 cuda 版本 11.3
pip3 install install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

安装 MMCV

pip install -U openmim
mim install mmengine
mim install "mmcv==2.0.0rc4"

安装其他工具包

包括 cv2: opencv-python, pillow, matplotlib, seaborn, tqdm, pytorch-lightning, mmdet

pip install opencv-python pillow matplotlib seaborn tqdm pytorch-lightning 'mmdet>=3.0.0rc1' -i https://pypi.tuna.tsinghua.edu.cn/simple

下载和安装MMSegmentation

git clone https://github.com/open-mmlab/mmsegmentation.git -b dev-1.x
cd mmsegmentation
pip install -e .

准备工作

  1. 准备文件夹
  • data:存放自己的数据
  • outputs:存放模型的输出结果
  • checkpoint:存放模型权重
mkdir data outputs checkpoint
  1. 下载预训练权重

Model Zoo:https://github.com/open-mmlab/mmsegmentation/blob/master/docs/en/model_zoo.md

# pspnet cityscapes 的模型权重
wget https://download.openmmlab.com/mmsegmentation/v0.5/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes/pspnet_r50-d8_512x1024_40k_cityscapes_20200605_003338-2966598c.pth -P checkpoint
  1. 下载素材
# 伦敦街景图片
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_uk.jpeg -P data

# 上海驾车街景视频,视频来源:https://www.youtube.com/watch?v=ll8TgCZ0plk
!wget https://zihao-download.obs.cn-east-3.myhuaweicloud.com/detectron2/traffic.mp4 -P data

# 街拍视频,2022年3月30日
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20220713-mmdetection/images/street_20220330_174028.mp4 -P data

检查安装

  1. 检查pytorch
import torch
print(torch.__version__)
print(torch.cuda.is_available())
  1. 检查mmcv,cuda 和 编译器版本
import mmcv
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
print("mmcv 版本:", mmcv.__version__)
print("CUDA 版本:", get_compiling_cuda_version())
print("GCC 版本:", get_compiler_version()) 
  1. 检查 mmsegmentation 版本
import mmseg
from mmseg.utils import register_all_modules
from mmseg.apis import inference_model, init_model
print("mmsegmentation 版本": mmseg.__version__)

使用命令行进行预测和可视化

使用 demo/image_demo.py 对图像预测

python demo/image_demo.py \
	img \
	config \
	checkpoint \
	--out-dir outputs \
	--device cuda:0
	--opacity 0.5

使用 demo/video_demo.py 对视频进行预测

python demo/video_demo.py \
	video \
	config \
	checkpoint \
	--device cuda:0 \
	--output-file outputs/b3_video.mp4 \
	--opacity 0.5

使用 Python API 预测

import torch.cuda

from mmseg.apis import init_model, inference_model
from mmseg.datasets import cityscapes
from mmengine.model.utils import revert_sync_batchnorm
from mmengine import ProgressBar
import mmcv
from PIL import Image
import numpy as np
import shutil
import time
import os


def predict_single_frame(model, img, palette, opacity=0.2):
    result = inference_model(model, img)

    seg_map = np.array(result.pred_sem_seg.data[0].detach().cpu().numpy()).astype('uint8')
    seg_map = Image.fromarray(seg_map).convert("P")
    seg_map.putpalette(np.array(palette, dtype="uint8"))

    show_img = np.array(seg_map.convert("RGB")) * (1 - opacity) + img * opacity

    return show_img


def predict_video(model, video_path, palette, temp_dir=None, out_file=None):
    imgs = mmcv.VideoReader(video_path)
    pgb = ProgressBar(len(imgs))
    out_dir = os.path.dirname(out_file)
    if temp_dir is None:
        temp_dir = os.path.join(out_dir, "temp_dir")
        assert not os.path.exists(temp_dir), "please clarity the temp_dir"
        os.makedirs(temp_dir)

    for frame_id, img in enumerate(imgs):
        show_img = predict_single_frame(model, img, palette)
        out_frame = f"{temp_dir}/{frame_id:06d}.jpg"
        mmcv.imwrite(show_img, out_frame)

        pgb.update()

    mmcv.frames2video(temp_dir, out_file, fps=imgs.fps, fourcc="mp4v")

    shutil.rmtree(temp_dir)
    print(f"删除临时文件夹{temp_dir}")


if __name__ == '__main__':
    config_path = os.environ["CONFIG"]
    checkpoint = os.environ["CKPT"]
    video_path = os.environ["VIDEO"]

    model = init_model(config_path, checkpoint, device="cuda:0")
    if not torch.cuda.is_available():
        model = revert_sync_batchnorm(model)

    # 从cityscapes 获取类别和类别可视化颜色
    classes = cityscapes.CityscapesDataset.METAINFO["classes"]
    palette = cityscapes.CityscapesDataset.METAINFO["palette"]

    time_ = time.strftime("%Y%m%d%H%M%S")
    predict_video(model, video_path, palette=palette, out_file=os.path.join("outputs", time_, "out_video.mp4"))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值