MMSegmentation代码课

本文详细介绍了如何安装配置MMSegmentation库,包括安装PyTorch和其他依赖项,下载并处理西瓜像素级语义分割数据,以及数据的可视化。接着,文章展示了如何自定义配置文件以适应特定任务,训练模型,并进行模型预测。最后,对模型在测试集上的性能进行了评估。
摘要由CSDN通过智能技术生成

一、安装配置MMSegmentation

安装pytorch:

# 安装Pytorch
!pip3 install install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html

安装mmengine和mmcv依赖

# 安装 mmengine 和 mmcv 依赖
# 为了防止后续版本变更导致的代码无法运行,我们暂时锁死版本
!pwd
%pip install -U "openmim==0.3.7"
!mim install "mmengine==0.7.1"
!mim install "mmcv==2.0.0"

!pip install opencv-python pillow matplotlib seaborn tqdm pytorch-lightning 'mmdet>=3.0.0rc1'
# 从 github 上下载最新的 mmsegmentation 源代码
!git clone https://github.com/open-mmlab/mmsegmentation.git -b dev-1.x
# 进入主目录
%cd mmsegmentation


!pip install -v -e .

查看环境是否安装成功

from mmengine.utils import get_git_hash
from mmengine.utils.dl_utils import collect_env as collect_base_env

import mmseg

# 环境信息收集和打印
def collect_env():
    """Collect the information of the running environments."""
    env_info = collect_base_env()
    env_info['mmseg'] = f'{mmseg.__version__}+{get_git_hash()[:7]}'
    return env_info


if __name__ == '__main__':
    for name, val in collect_env().items():
        print(f'{name}: {val}')

二、数据准备和可视化

这里准备的是西瓜瓤、西瓜皮、西瓜籽像素级语义分割

下载数据:


import os

# 创建 checkpoint 文件夹,用于存放预训练模型权重文件
os.mkdir('checkpoint')

# 创建 outputs 文件夹,用于存放预测结果
os.mkdir('outputs')

# 创建 data 文件夹,用于存放图片和视频素材
os.mkdir('data')
%cd ./data
! wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Labelme.zip
! wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/dataset/watermelon/Watermelon87_Semantic_Seg_Mask.zip
! unzip Watermelon87_Semantic_Seg_Labelme.zip
! unzip Watermelon87_Semantic_Seg_Mask.zip

数据可视化:

import os

import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm

import matplotlib.pyplot as plt
%matplotlib inline

# 指定图像和标注路径
PATH_IMAGE = 'Watermelon87_Semantic_Seg_Mask/img_dir/train'
PATH_MASKS = 'Watermelon87_Semantic_Seg_Mask/ann_dir/train'
# n行n列可视化
n = 5

# 标注区域透明度
opacity = 0.5

fig, axes = plt.subplots(nrows=n, ncols=n, sharex=True, figsize=(12,12))

for i, file_name in enumerate(os.listdir(PATH_IMAGE)[:n**2]):

    # 载入图像和标注
    img_path = os.path.join(PATH_IMAGE, file_name)
    mask_path = os.path.join(PATH_MASKS, file_name.split('.')[0]+'.png')
    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path)

    # 可视化
    axes[i//n, i%n].imshow(img)
    axes[i//n, i%n].imshow(mask[:,:,0], alpha=opacity)
    axes[i//n, i%n].axis('off') # 关闭坐标轴显示
fig.suptitle('Image and Semantic Label', fontsize=30)
plt.tight_layout()
plt.show()

  • / 背景 / 0
  • red 西瓜红瓤 多段线(polygon) 1
  • green 西瓜外壳 多段线(polygon) 2
  • white 西瓜白皮 多段线(polygon) 3
  • seed-black 西瓜黑籽 多段线(polygon) 4
  • seed-white 西瓜白籽 多段线(polygon) 5

 三、自定义配置文件

%cd /content/mmsegmentation/

!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/DubaiDataset.py -P mmseg/datasets

!rm -rf mmseg/datasets/__init__.py # 删除原有文件
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/__init__.py -P mmseg/datasets

!rm -rf configs/_base_/datasets/DubaiDataset_pipeline.py
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/DubaiDataset_pipeline.py -P configs/_base_/datasets

!rm -rf configs/pspnet/pspnet_r50-d8_4xb2-40k_DubaiDataset.py # 删除原有文件
!wget https://zihao-openmmlab.obs.cn-east-3.myhuaweicloud.com/20230130-mmseg/Dubai/pspnet_r50-d8_4xb2-40k_DubaiDataset.py -P configs/pspn

修改mmseg/datasets/DubaiDataset.py 中的classes,

修改configs/pspn/pspnet_r50-d8_4xb2-40k_DubaiDataset.py  中的dataset_root 为数据存放地址;

生成自己的配置文件:

from mmengine import Config
cfg = Config.fromfile('./configs/pspn/pspnet_r50-d8_4xb2-40k_DubaiDataset.py')

cfg.norm_cfg = dict(type='BN', requires_grad=True) # 只使用GPU时,BN取代SyncBN
cfg.crop_size = (256, 256)
cfg.model.data_preprocessor.size = cfg.crop_size
cfg.model.backbone.norm_cfg = cfg.norm_cfg
cfg.model.decode_head.norm_cfg = cfg.norm_cfg
cfg.model.auxiliary_head.norm_cfg = cfg.norm_cfg
# modify num classes of the model in decode/auxiliary head

# 模型 decode/auxiliary 输出头,指定为类别个数
cfg.model.decode_head.num_classes = 6
cfg.model.auxiliary_head.num_classes = 6

cfg.train_dataloader.batch_size = 8

cfg.test_dataloader = cfg.val_dataloader

# 结果保存目录
cfg.work_dir = './work_dirs/DubaiDataset'

# 训练迭代次数
cfg.train_cfg.max_iters = 3000
# 评估模型间隔
cfg.train_cfg.val_interval = 400
# 日志记录间隔
cfg.default_hooks.logger.interval = 100
# 模型权重保存间隔
cfg.default_hooks.checkpoint.interval = 1500

# 随机数种子
cfg['randomness'] = dict(seed=0)
# cfg.data_root = '/content/mmsegmentation/data/Watermelon87_Semantic_Seg_Mask'

print(cfg.pretty_text)

cfg.dump('pspnet-DubaiDataset_20230612.py')

四、模型训练

import numpy as np

import os.path as osp
from tqdm import tqdm

import mmcv
import mmengine
from mmengine import Config
cfg = Config.fromfile('pspnet-DubaiDataset_20230612.py')
from mmengine.runner import Runner
from mmseg.utils import register_all_modules

# register all modules in mmseg into the registries
# do not init the default scope here because it will be init in the runner
register_all_modules(init_default_scope=False)
runner = Runner.from_cfg(cfg)
runner.train()

五、模型预测

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from mmseg.apis import init_model, inference_model, show_result_pyplot
import mmcv
import cv2

# 载入 config 配置文件
from mmengine import Config
cfg = Config.fromfile('pspnet-DubaiDataset_20230612.py')

from mmengine.runner import Runner
from mmseg.utils import register_all_modules

# register all modules in mmseg into the registries
# do not init the default scope here because it will be init in the runner

register_all_modules(init_default_scope=False)
runner = Runner.from_cfg(cfg)

checkpoint_path = './work_dirs/DubaiDataset/iter_3000.pth'
model = init_model(cfg, checkpoint_path, 'cuda:0')

img = mmcv.imread('/content/mmsegmentation/data/Watermelon87_Semantic_Seg_Mask/img_dir/val/L007-05_5.jpg')
result = inference_model(model, img)
result.keys()
pred_mask = result.pred_sem_seg.data[0].cpu().numpy()
pred_mask.shape
np.unique(pred_mask)
plt.imshow(pred_mask)
plt.show()

# 可视化预测结果
visualization = show_result_pyplot(model, img, result, opacity=0.7, out_file='pred.jpg')
plt.imshow(mmcv.bgr2rgb(visualization))
plt.show()

# 可视化标签
label = mmcv.imread('/content/mmsegmentation/data/Watermelon87_Semantic_Seg_Mask/ann_dir/val/L007-05_5.png')
label.shape
label_mask = label[:,:,0]
label_mask.shape
np.unique(label_mask)

plt.imshow(label_mask)
plt.show()

六、测试集性能评估

!python tools/test.py pspnet-DubaiDataset_20230612.py work_dirs/DubaiDataset/iter_3000.pth

 | Class | IoU | Acc |

| bk | 72.1 | 96.71 |

| red | 52.61 | 55.4 |

| green | 28.49 | 32.65 |

| white | 41.2 | 42.24 |

| seed-black | 59.86 | 64.06 |

| seed-white | 0.74 | 0.74

aAcc: 77.2500 mIoU: 42.5000 mAcc: 48.6300 data_time: 0.0037 time: 4.6348

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值